import copy

import torch
import subprocess
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

from src.matrix_completion.w_layer import WLayer
from src.matrix_completion.ab_layer import ABLayer
from src.matrix_completion.svd_layer import SVDLayer
from src.matrix_completion.aug_bug_layer import AugBUGLayer
from src.matrix_completion.adalora_layer import AdaloraLayer

# from src.matrix_completion.aug_bug_diagS_layer import AugBUGDiagSLayer
from src.matrix_completion.parallel_low_rank_layer import ParallelLowRankLayer


import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import math
import csv
import time


def get_available_device():
    """Return the CUDA device whose GPU reports the smallest used memory.

    Used memory (MiB) per GPU is read from ``nvidia-smi``; ties resolve to
    the lowest index.

    NOTE(review): ``nvidia-smi`` enumerates physical GPUs and ignores
    ``CUDA_VISIBLE_DEVICES``, so its index may not match torch's ``cuda:N``
    numbering when that variable is set - confirm on multi-GPU hosts.
    """
    query = "nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits"
    raw_output = subprocess.check_output(query.split()).decode()
    used_per_gpu = [int(line.strip()) for line in raw_output.strip().split("\n")]

    # min() with a key returns the first minimal index, like list.index(min(...)).
    least_busy = min(range(len(used_per_gpu)), key=used_per_gpu.__getitem__)
    return torch.device(f"cuda:{least_busy}")


def create_lr_matrix(size_in, size_out, rank, sigmas=None, check_reconstruction=False):
    """
    Creates a random ``size_in x size_out`` matrix of rank ``rank``.

    The matrix is ``U @ diag(sigmas) @ V.T`` where ``U`` and ``V`` have
    orthonormal columns obtained from reduced QR factorizations of Gaussian
    matrices.

    Parameters:
    size_in (int): number of rows.
    size_out (int): number of columns.
    rank (int): rank of the result; must not exceed ``min(size_in, size_out)``.
    sigmas (sequence of float, optional): the ``rank`` singular values. When
        omitted, the fixed spectrum (2.3, 2, 1.7, 1.2, 0.7) is used if
        ``rank == 5``, and sorted squared Gaussian samples otherwise.
    check_reconstruction (bool): if True, print the relative Frobenius error of
        an SVD round trip of the result. Off by default because a full SVD is
        cubic in the matrix size.

    Returns:
    torch.Tensor: the ``size_in x size_out`` low-rank matrix.

    Raises:
    ValueError: if ``sigmas`` is given but does not contain ``rank`` values.
    """
    default_sigmas = (2.3, 2.0, 1.7, 1.2, 0.7)

    U, _ = torch.linalg.qr(torch.randn(size_in, rank), "reduced")
    V, _ = torch.linalg.qr(torch.randn(size_out, rank), "reduced")
    # Always draw the random spectrum so the RNG stream (and therefore U, V of
    # later calls) is the same regardless of which spectrum is finally used.
    random_sigmas, _ = torch.sort(torch.randn(rank) ** 2, descending=True)

    if sigmas is not None:
        sigmas = torch.as_tensor(sigmas, dtype=U.dtype)
        if sigmas.shape != (rank,):
            raise ValueError(
                f"expected {rank} singular values, got shape {tuple(sigmas.shape)}"
            )
    elif rank == len(default_sigmas):
        # Historical fixed spectrum; only valid for exactly five values.
        sigmas = torch.Tensor(default_sigmas)
    else:
        sigmas = random_sigmas

    W = U @ torch.diag(sigmas) @ V.T

    if check_reconstruction:
        # Reduced SVD so the round trip also works for non-square matrices.
        P, d, Q = torch.linalg.svd(W, full_matrices=False)
        W2 = P @ torch.diag(d) @ Q
        print(torch.norm(W - W2, p="fro") / torch.norm(W, p="fro"))

    return W


def main():
    """Train a single ansatz on a synthetic matrix-completion problem.

    All settings (sizes, rank, step size, which low-rank approach, and the
    ``lora`` switch) are hard-coded below.

    Returns:
    int: always 0.
    """
    device = torch.device(
        get_available_device() if torch.cuda.is_available() else "cpu"
    )

    n_iter = 10000
    step_size = 1e-3
    matrix_size = 2000
    rank = 5
    init_rank = 40
    tau = 0.01
    low_rank_approach = 1
    lora = True
    if low_rank_approach == 0:  # Full rank
        ansatz = WLayer(matrix_size, matrix_size).to(device)  # full rank
    elif low_rank_approach == 1:  # BUG
        ansatz = AugBUGLayer(
            matrix_size, matrix_size, rank=init_rank, tau=tau, bias=False
        ).to(device)
    elif low_rank_approach == 2:  # Parallel
        ansatz = ParallelLowRankLayer(
            matrix_size, matrix_size, rank=init_rank, tau=tau, bias=False
        ).to(device)
    else:
        ansatz = ABLayer(matrix_size, matrix_size, rank=init_rank, bias=True).to(device)

    # NOTE(review): here ``lora=True`` selects the purely low-rank target,
    # whereas in main_run_all ``lora=True`` adds the dense matrix - confirm
    # which convention is intended.
    if lora:
        # Target is a purely low-rank matrix.
        reference = create_lr_matrix(matrix_size, matrix_size, rank).to(device)
        matrix_completion(reference, ansatz, n_iter, step_size, low_rank_approach)
    else:
        # Target is a dense random "pretrained" matrix plus a low-rank update.
        # matrix_completion takes a single target matrix, so fit the combined
        # matrix (as main_run_all does for its dense-plus-low-rank target);
        # passing both matrices would shift every later argument by one.
        pretrained = torch.randn(matrix_size, matrix_size).to(device)
        lora_W = pretrained + create_lr_matrix(matrix_size, matrix_size, rank).to(
            device
        )
        matrix_completion(lora_W, ansatz, n_iter, step_size, low_rank_approach)

    return 0


def _build_ansatz(template, low_rank_approach, init_rank, tau, matrix_size, device):
    """Build the ansatz for one approach from a private deep copy of ``template``.

    Every approach starts from the factors ``U``, ``S``, ``VT`` of the shared
    ParallelLowRankLayer ``template``; copying first ensures training one ansatz
    never modifies the starting point of the next.

    Parameters:
    template (ParallelLowRankLayer): shared initial factorization.
    low_rank_approach (int): 0 full rank, 1 augmented BUG, 2 parallel,
        3 LoRA, 4 SVD-LoRA, 5 AdaLoRA.
    init_rank (int): initial rank of the low-rank factors.
    tau (float): truncation tolerance forwarded to the rank-adaptive layers.
    matrix_size (int): output size of the layer.
    device (torch.device): device the ansatz is moved to.

    Returns:
    torch.nn.Module: the ansatz on ``device``.

    Raises:
    ValueError: if ``low_rank_approach`` is not in 0..5.
    """
    t_ans = copy.deepcopy(template)
    if low_rank_approach == 0:  # Full rank
        ansatz = WLayer(
            W_init=t_ans.U[:, :init_rank]
            @ t_ans.S[:init_rank, :init_rank]
            @ t_ans.VT[:init_rank, :]
        )
    elif low_rank_approach == 1:  # BUG
        ansatz = AugBUGLayer(
            U=t_ans.U,
            S=t_ans.S,
            V=t_ans.VT.T,
            rank=init_rank,
            tau=tau,
            bias=False,
            output_size=matrix_size,
        )
    elif low_rank_approach == 2:  # Parallel
        ansatz = t_ans
    elif low_rank_approach == 3:  # LoRA
        ansatz = ABLayer(
            A=t_ans.U[:, :init_rank],
            BT=t_ans.S[:init_rank, :init_rank] @ t_ans.VT[:init_rank, :],
            input_size=matrix_size,
            output_size=matrix_size,
            rank=init_rank,
        )
    elif low_rank_approach == 4:  # SVD LoRA
        ansatz = SVDLayer(
            A=t_ans.U,
            S=t_ans.S,
            BT=t_ans.VT,
            output_size=matrix_size,
            rank=init_rank,
        )
    elif low_rank_approach == 5:  # AdaLoRA
        ansatz = AdaloraLayer(
            U=t_ans.U,
            S=t_ans.S,
            VT=t_ans.VT,
            output_size=matrix_size,
            rank=init_rank,
            tau=tau,
        )
    else:
        raise ValueError(f"Unknown low_rank_approach: {low_rank_approach}")
    return ansatz.to(device)


def main_run_all():
    """Run every fine-tuning approach (0..5) on the same problem and save the losses.

    All approaches start from one shared ParallelLowRankLayer whose leading
    ``init_rank`` block of ``S`` is zeroed and whose ``Sinv`` block is set to
    the identity. The per-approach loss histories are written to
    ``loss_lists.csv``.

    Returns:
    int: always 0.
    """
    device = torch.device(
        get_available_device() if torch.cuda.is_available() else "cpu"
    )

    n_iter = 1000
    step_size = 0.1
    matrix_size = 5000
    rank = 5
    init_rank = 50
    coeff_steps = 1
    tau = 0.15
    lora = False
    all_loss_values = []

    print("create ansatz")
    ans = ParallelLowRankLayer(
        matrix_size, matrix_size, rank=init_rank, tau=tau, bias=False
    )
    with torch.no_grad():
        ans.S[:init_rank, :init_rank] *= 0
        ans.Sinv[:init_rank, :init_rank] = ans.Sinv[
            :init_rank, :init_rank
        ] * 0 + torch.diag(torch.ones(init_rank))

    print("create reference")
    if lora:
        # This packs the frozen, "pretrained" W matrix into the reference,
        # since it is not trainable.
        reference = create_lr_matrix(matrix_size, matrix_size, rank).to(
            device
        ) + torch.randn(matrix_size, matrix_size).to(device)
    else:
        # The reference is purely low-rank.
        reference = create_lr_matrix(matrix_size, matrix_size, rank).to(device)

    for low_rank_approach in [0, 1, 2, 3, 4, 5]:
        ansatz = _build_ansatz(
            ans, low_rank_approach, init_rank, tau, matrix_size, device
        )

        # Inspect the starting output without recording an autograd graph,
        # and release it before training so it does not occupy memory.
        with torch.no_grad():
            initial_cond = ansatz()
        print("initial cond", initial_cond)
        del initial_cond

        loss_values = matrix_completion(
            reference,
            ansatz,
            n_iter,
            step_size,
            low_rank_approach,
            coeff_steps=coeff_steps,
        )
        all_loss_values.append(loss_values)

    save_loss_lists_to_file(all_loss_values, "loss_lists.csv")
    return 0


# Display name of each training setup, indexed by the position of its loss
# list in loss_lists.csv (main_run_all writes them in approach order 0..5).
_SETUP_NAMES = (
    "Full FT",
    "Aug BUG FT",
    "Parallel BUG FT",
    "LoRA FT",
    "SVD LoRA FT",
    "AdaLoRa FT",
)

# Directory all matrix-completion figures are written to.
_PLOT_DIR = "img_matrix_completion"


def _save_current_figure(basename):
    """Save the active figure as ``<_PLOT_DIR>/<basename>.png`` and ``.pdf`` at 300 dpi."""
    for extension in ("png", "pdf"):
        plt.savefig(f"{_PLOT_DIR}/{basename}.{extension}", dpi=300)


def _draw_lineplot(df, y, ylabel, xlog, ylog, ylim=None):
    """Plot column ``y`` of ``df`` against "Iteration" on a new figure, one line per optimizer.

    Returns the seaborn/matplotlib Axes so callers can adjust it further.
    """
    plt.figure(figsize=(10, 6))
    ax = sns.lineplot(data=df, x="Iteration", y=y, hue="Optimizer", linewidth=2.5)
    plt.xlabel("Iteration")
    plt.ylabel(ylabel)
    plt.legend(title="Training Setup", loc="best")
    plt.grid(True, linestyle="--", linewidth=0.7)
    if ylim is not None:
        ax.set_ylim(ylim)
    if xlog:
        ax.set_xscale("log")
    if ylog:
        ax.set_yscale("log")
    # Remove top and right spines for a cleaner look
    sns.despine()
    plt.tight_layout()
    return ax


def plot_loss_graphs():
    """
    Plots absolute error, relative error, time, and rank curves for every
    training setup stored in ``loss_lists.csv`` using Seaborn.

    Each row of the file is one setup; each entry is
    ``(absolute loss, relative error, rank, iteration time)``. Figures are
    written as PNG and PDF into ``img_matrix_completion/`` (created if missing).
    The full-rank setup (index 0) is left out of the rank plot.

    Raises:
    ValueError: if the file holds more non-empty setups than there are names
        in ``_SETUP_NAMES``.
    """
    import os

    loss_lists = read_loss_lists_from_file("loss_lists.csv")

    # Prepare the data for plotting
    data_absolute = []
    data_relative = []
    data_ranks = []
    data_time = []

    for i, losses in enumerate(loss_lists):
        if not losses:
            continue
        if i >= len(_SETUP_NAMES):
            raise ValueError(f"Invalid index {i}: no setup name for this loss list")
        setup = _SETUP_NAMES[i]

        for epoch, (abs_loss, rel_loss, rank, elapsed) in enumerate(losses):
            data_absolute.append(
                {"Iteration": epoch, "Error": float(abs_loss), "Optimizer": setup}
            )
            data_relative.append(
                {"Iteration": epoch, "Error": float(rel_loss), "Optimizer": setup}
            )
            if i != 0:  # the full-rank setup is excluded from the rank plot
                data_ranks.append(
                    {"Iteration": epoch, "Rank": float(rank), "Optimizer": setup}
                )
            data_time.append(
                {"Iteration": epoch, "Time": float(elapsed), "Optimizer": setup}
            )

    os.makedirs(_PLOT_DIR, exist_ok=True)

    # Set the style for a paper
    sns.set(style="whitegrid")
    sns.set_context("paper", font_scale=2.2)

    ############ ABSOLUTE LOSS ############
    # Log-log first, then the same figure again with linear axes.
    ax_abs = _draw_lineplot(
        pd.DataFrame(data_absolute),
        "Error",
        r"$||W-W_{\text{ans}}||_{F}$",
        xlog=True,
        ylog=True,
        ylim=[1e-6, 15],
    )
    _save_current_figure("matrix_completion_absolute_loss_graphs")
    ax_abs.set_yscale("linear")
    ax_abs.set_xscale("linear")
    _save_current_figure("matrix_completion_absolute_loss_graphs_liny")
    plt.close()

    ############ RELATIVE LOSS ############
    _draw_lineplot(
        pd.DataFrame(data_relative),
        "Error",
        "Relative Error",
        xlog=True,
        ylog=True,
        ylim=[5e-4, 1e0],
    )
    _save_current_figure("matrix_completion_relative_loss_graphs")
    plt.close()

    ############ TIME ############
    # Linear iteration axis, logarithmic time axis.
    _draw_lineplot(
        pd.DataFrame(data_time), "Time", "Time [s]", xlog=False, ylog=True
    )
    _save_current_figure("matrix_completion_time_graphs")
    plt.close()

    ############ RANK ############
    if data_ranks:
        _draw_lineplot(pd.DataFrame(data_ranks), "Rank", "Rank", xlog=True, ylog=True)
        _save_current_figure("matrix_completion_rank_graphs")
        plt.close()


def save_loss_lists_to_file(loss_lists, filename):
    """
    Writes the loss history of several training setups to a CSV file.

    Every setup becomes one CSV row. Each ``(abs_loss, rel_loss, rank, time)``
    record becomes one cell of the form ``"abs,rel,rank, time"``; the csv
    writer quotes it because it contains commas.

    Parameters:
    loss_lists (list of lists): one list of 4-tuples per training setup.
    filename (str): path of the CSV file to (over)write.
    """

    def _encode(record):
        abs_loss, rel_loss, rank, elapsed = record
        return f"{abs_loss},{rel_loss},{rank}, {elapsed}"

    with open(filename, mode="w", newline="") as handle:
        csv.writer(handle).writerows(
            [_encode(record) for record in loss_list] for loss_list in loss_lists
        )


def read_loss_lists_from_file(filename):
    """
    Loads the loss histories written by ``save_loss_lists_to_file``.

    Each CSV row is one training setup; each cell holds four comma-separated
    numbers ``abs_loss,rel_loss,rank,time`` which are parsed as floats.

    Parameters:
    filename (str): path of the CSV file to read.

    Returns:
    list of lists: one list of ``(abs_loss, rel_loss, rank, time)`` float
    tuples per training setup.
    """

    def _decode(cell):
        abs_loss, rel_loss, rank, elapsed = cell.split(",")
        return (float(abs_loss), float(rel_loss), float(rank), float(elapsed))

    with open(filename, mode="r") as handle:
        return [[_decode(cell) for cell in row] for row in csv.reader(handle)]


def matrix_completion(
    reference, ansatz, n_iter, step_size, low_rank_approach, coeff_steps=10
):
    """
    Fits ``ansatz()`` to ``reference`` and records the training history.

    Each iteration evaluates ``ansatz()``, computes the summed squared error
    against ``reference`` and back-propagates it. The parameter update itself
    is delegated to ``ansatz.step``. The SGD optimizer created here is never
    stepped; it is only used to zero gradients and to report the learning rate.

    Parameters:
    reference (torch.Tensor): target matrix.
    ansatz (torch.nn.Module): called with no arguments; must expose
        ``step(learning_rate=...)`` (plus ``dlrt_step=`` for approach 1),
        ``r`` for approaches 1-5, ``W`` for approach 0, and
        ``ortho_regularization()`` for approach 5.
    n_iter (int): maximum number of iterations.
    step_size (float): learning rate passed to ``ansatz.step``.
    low_rank_approach (int): 0 full rank, 1 augmented BUG, 2 parallel,
        3 LoRA, 4 SVD-LoRA, 5 AdaLoRA.
    coeff_steps (int): for approach 1, a basis step is taken every
        ``coeff_steps`` iterations and coefficient steps otherwise. A value
        of 1 is raised to 2 so the basis step is not taken every iteration.

    Returns:
    list: ``[loss, relative_error, rank, iter_time]`` for the initial state
    (``iter_time`` 0), then one entry per executed iteration. Training stops
    early once the loss drops below 1e-6 or becomes NaN.

    Raises:
    ValueError: if ``low_rank_approach`` is not in 0..5.
    """
    if low_rank_approach not in (0, 1, 2, 3, 4, 5):
        raise ValueError(f"Unknown low_rank_approach: {low_rank_approach}")

    criterion = nn.MSELoss(reduction="sum")
    optimizer = optim.SGD(
        ansatz.parameters(), lr=step_size, momentum=0.0, weight_decay=0.0
    )
    ansatz.train()

    def _relative_error(output):
        """Return ||output - reference||_F / ||reference||_F without tracking gradients."""
        with torch.no_grad():
            return torch.norm(output - reference, p="fro") / torch.norm(
                reference, p="fro"
            )

    def _current_rank():
        """Return the logged rank: the full dimension for approach 0, else ``ansatz.r``."""
        if low_rank_approach == 0:
            return min(ansatz.W.shape[0], ansatz.W.shape[1])
        return ansatz.r

    def _synchronize():
        """Wait for queued CUDA work so wall-clock timings include the real compute."""
        if reference.is_cuda:
            torch.cuda.synchronize(reference.device)

    loss_values = []
    coef_steps = coeff_steps
    if low_rank_approach == 1 and coef_steps == 1:
        coef_steps = 2  # BUG needs a dedicated basis update step
    with tqdm(total=n_iter, desc="Matrix Completion") as pbar:
        for it in range(n_iter):
            # CUDA kernels run asynchronously; synchronize on both ends of the
            # timed region so iter_time measures computation, not kernel launches.
            _synchronize()
            start_time = time.perf_counter()

            output = ansatz()
            loss = criterion(output, reference)
            if low_rank_approach == 5:  # AdaLora orthogonality penalty
                loss = loss + 1e-4 * ansatz.ortho_regularization()

            optimizer.zero_grad()
            loss.backward()
            if it == 0:
                # Record the initial state before any update.
                t = _relative_error(output)
                loss_values.append(
                    [loss.item(), t.detach().cpu().numpy(), _current_rank(), 0]
                )

            if low_rank_approach == 1:  # BUG
                if it % coef_steps == 0:
                    ansatz.step(learning_rate=step_size, dlrt_step="basis")
                    # Second evaluation after the basis update (intended for the
                    # coefficient step). NOTE(review): the next iteration
                    # recomputes gradients before its coefficient step, so here
                    # these values only feed the logged loss/error - confirm
                    # whether this backward pass is needed.
                    output = ansatz()
                    loss = criterion(output, reference)
                    optimizer.zero_grad()
                    loss.backward()
                else:
                    ansatz.step(learning_rate=step_size, dlrt_step="coefficients")
                    if it % coef_steps == coef_steps - 1:
                        ansatz.step(learning_rate=step_size, dlrt_step="truncate")
            else:
                ansatz.step(learning_rate=step_size)
            pbar.update(1)

            t = _relative_error(output)
            rank = _current_rank()

            _synchronize()
            iter_time = time.perf_counter() - start_time

            loss_value = loss.item()
            pbar.set_description(
                f"Training, Loss: {loss_value:.6f}, Fornorm: {t:.6f}, Learning rate: {optimizer.param_groups[0]['lr']:.6f}, Rank: {rank}, Iter time: {iter_time:.2f}s"
            )
            loss_values.append([loss_value, t.detach().cpu().numpy(), rank, iter_time])
            if loss_value < 1e-6 or math.isnan(loss_value):
                break
    return loss_values


if __name__ == "__main__":
    # Plot from the existing loss_lists.csv; re-enable main_run_all() to
    # regenerate that file by re-running every experiment first.
    # main_run_all()
    plot_loss_graphs()
