import os
import torch
from datetime import datetime
import csv

from .plotting_utils import *


class ExperimentInfo:
    """Per-run bookkeeping: output directory, reproducibility settings, result logs.

    Constructing an instance has process-wide side effects: it sets torch's
    default dtype to ``float32``, seeds torch's RNG from ``RUN``, disables
    cuDNN autotuning and requests deterministic cuDNN algorithms, and creates
    the run directory ``<output_dir>/<TARGET>/<exp_type>/<YYYYmmdd-HHMMSS>``.

    Each ``log_*`` method creates its own sub-directory of :attr:`save_path`
    (which must not already exist) and writes a ``results.csv`` plus PNG plots
    into it.
    """

    # Permission bits applied to every directory this class creates.
    # NOTE(review): 0o7777 also sets the setuid, setgid and sticky bits on the
    # directories; plain 0o777 was probably intended -- confirm before changing.
    DIR_MODE = 0o7777

    def __init__(
        self,
        version,
        RUN,
        description,
        data_dir,
        output_dir,
        device,
        target,
        exp_type,
    ):
        """Create the run directory and configure torch for reproducibility.

        Args:
            version: Stored on the instance; not read by any method here.
            RUN: Run index. ``torch.manual_seed(RUN + 12)`` is applied unless
                ``RUN`` is ``None``.
            description: Stored on the instance; not read by any method here.
            data_dir: Stored on the instance; not read by any method here.
            output_dir: Root under which the run directory is created.
            device: Stored on the instance; not read by any method here.
            target: Upper-cased and used as the first path component.
            exp_type: Used as the second path component.

        Raises:
            FileExistsError: If the timestamped run directory already exists,
                e.g. two runs with the same target/exp_type started within the
                same second.
        """
        self._version = version
        self._description = description
        self._data_dir = data_dir
        self._exp_type = exp_type
        self._output_dir = os.path.join(
            output_dir,
            target.upper(),
            exp_type,
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        self._device = device

        torch.set_default_dtype(torch.float32)
        os.makedirs(self._output_dir)
        os.chmod(self._output_dir, self.DIR_MODE)

        # 0 is a valid seed, so test RUN for None rather than the seed for
        # truthiness.
        if RUN is not None:
            torch.manual_seed(RUN + 12)

        # Disable cuDNN autotuning and request deterministic cuDNN algorithms
        # so repeated runs are reproducible.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    @property
    def save_path(self):
        """Path of this run's timestamped output directory."""
        return self._output_dir

    # ------------------------------------------------------------------ helpers

    def _make_subdir(self, name):
        """Create ``save_path/<name>`` (must not exist), chmod it, return its path."""
        path = os.path.join(self._output_dir, name)
        os.makedirs(path)
        os.chmod(path, self.DIR_MODE)
        return path

    @staticmethod
    def _write_csv(path, fieldnames, rows):
        """Write *rows* (dicts keyed by *fieldnames*) to *path* with a header line."""
        with open(path, "w", newline="") as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)

    @staticmethod
    def _indexed_columns(prefix, rows):
        """Return ``["<prefix> 0", "<prefix> 1", ...]``, one per slot of the widest row.

        *rows* is a list of sized sequences; an empty list yields no columns.
        """
        width = max((len(row) for row in rows), default=0)
        return [f"{prefix} {n}" for n in range(width)]

    @classmethod
    def _write_one_shot_csv(cls, path, sparsity, losses, accuracies, times):
        """Write the one-shot pruning table; ``accuracies[i]`` is a list per row."""
        fieldnames = ["Sparsity", "Loss", "Time"] + cls._indexed_columns(
            "Accuracy", accuracies
        )
        rows = []
        for sp, loss, accs, clk in zip(sparsity, losses, accuracies, times):
            row = {"Sparsity": sp, "Loss": loss, "Time": clk}
            row.update((f"Accuracy {n}", acc) for n, acc in enumerate(accs))
            rows.append(row)
        cls._write_csv(path, fieldnames, rows)

    @classmethod
    def _write_finetune_csv(
        cls,
        path,
        epochs,
        train_losses,
        train_accuracies,
        test_losses,
        test_accuracies,
        sparsity,
        times,
    ):
        """Write the per-epoch fine-tuning table.

        ``train_accuracies[i]`` and ``test_accuracies[i]`` are lists; each
        gets its own block of indexed columns sized to its widest row.
        """
        fieldnames = (
            ["Epoch", "Train Loss", "Test Loss", "Sparsity", "Time"]
            + cls._indexed_columns("Train Accuracy", train_accuracies)
            + cls._indexed_columns("Test Accuracy", test_accuracies)
        )
        rows = []
        for epoch, tr_loss, tr_accs, te_loss, te_accs, sp, clk in zip(
            epochs,
            train_losses,
            train_accuracies,
            test_losses,
            test_accuracies,
            sparsity,
            times,
        ):
            row = {
                "Epoch": epoch,
                "Train Loss": tr_loss,
                "Test Loss": te_loss,
                "Sparsity": sp,
                "Time": clk,
            }
            row.update((f"Train Accuracy {n}", acc) for n, acc in enumerate(tr_accs))
            row.update((f"Test Accuracy {n}", acc) for n, acc in enumerate(te_accs))
            rows.append(row)
        cls._write_csv(path, fieldnames, rows)

    # NOTE(review): figures returned by the plotting_utils helpers are never
    # closed here; if they are pyplot figures they stay registered with pyplot
    # and accumulate memory across calls -- consider closing them.

    # ------------------------------------------------------------ public logging

    def log_pretraining(self, scores):
        """Write pre-training metrics and plots to ``save_path/pretrain``.

        Args:
            scores: Mapping with per-epoch sequences under ``"Epochs"``,
                ``"train_losses"``, ``"train_accuracy"``, ``"test_losses"``,
                ``"test_accuracy"`` and ``"aux_losses"``. CSV rows are paired
                positionally with ``zip``, so surplus trailing entries in any
                sequence are dropped from the CSV.
        """
        pretrain_save_path = self._make_subdir("pretrain")

        epochs = scores["Epochs"]
        train_losses = scores["train_losses"]
        train_accuracy = scores["train_accuracy"]
        test_losses = scores["test_losses"]
        test_accuracy = scores["test_accuracy"]
        aux_losses = scores["aux_losses"]

        self._write_csv(
            os.path.join(pretrain_save_path, "results.csv"),
            [
                "Epoch",
                "Train Loss",
                "Train Accuracy",
                "Test Loss",
                "Test Accuracy",
                "Aux Losses",
            ],
            (
                {
                    "Epoch": epoch,
                    "Train Loss": tr_loss,
                    "Train Accuracy": tr_acc,
                    "Test Loss": te_loss,
                    "Test Accuracy": te_acc,
                    "Aux Losses": aux,
                }
                for epoch, tr_loss, tr_acc, te_loss, te_acc, aux in zip(
                    epochs,
                    train_losses,
                    train_accuracy,
                    test_losses,
                    test_accuracy,
                    aux_losses,
                )
            ),
        )

        fig = plot_loss_curve(
            epochs, train_losses, test_losses, train_accuracy, test_accuracy
        )
        fig.savefig(os.path.join(pretrain_save_path, "loss_curve.png"))

        fig = plot_aux_curve(epochs, aux_losses)
        fig.savefig(os.path.join(pretrain_save_path, "aux_curve.png"))

    def log_one_shot(self, scores):
        """Write one-shot pruning results and plots to ``save_path/1shot``.

        Args:
            scores: Mapping ``sparsity -> {"loss", "accuracy", "time_taken"}``
                where ``"accuracy"`` is an iterable of values written to
                columns ``Accuracy 0``, ``Accuracy 1``, ...
        """
        one_shot_dir = self._make_subdir("1shot")

        sparsity = list(scores.keys())
        losses = [score["loss"] for score in scores.values()]
        # Materialise each accuracy iterable so its length can size the header.
        accuracies = [list(score["accuracy"]) for score in scores.values()]
        times = [score["time_taken"] for score in scores.values()]

        self._write_one_shot_csv(
            os.path.join(one_shot_dir, "results.csv"),
            sparsity,
            losses,
            accuracies,
            times,
        )

        fig = plot_curve(sparsity, losses, xlabel="Sparsity", ylabel="Loss")
        fig.savefig(os.path.join(one_shot_dir, "loss.png"))

        # The first entry is left out of the timing plot; presumably it is the
        # unpruned baseline -- TODO confirm with callers.
        fig = plot_curve(
            sparsity[1:], times[1:], xlabel="Sparsity", ylabel="Pruning Time"
        )
        fig.savefig(os.path.join(one_shot_dir, "time.png"))

    def log_one_shot_FT(self, train_scores, val_scores):
        """Write one-shot prune + fine-tune results to ``save_path/1shot+FT``.

        See :meth:`_log_finetune` for the expected score layout.
        """
        self._log_finetune("1shot+FT", train_scores, val_scores)

    def log_gradual(self, train_scores, val_scores):
        """Write gradual-pruning results to ``save_path/gradual-prune``.

        See :meth:`_log_finetune` for the expected score layout.
        """
        self._log_finetune("gradual-prune", train_scores, val_scores)

    def _log_finetune(self, dir_name, train_scores, val_scores):
        """Shared implementation of :meth:`log_one_shot_FT` and :meth:`log_gradual`.

        Args:
            dir_name: Sub-directory of ``save_path`` to create.
            train_scores: Mapping ``epoch -> {"loss", "accuracy", "sparsity",
                "time_taken"}``, with ``"accuracy"`` an iterable of values.
            val_scores: Mapping whose values carry ``"loss"`` and an iterable
                ``"accuracy"``. It is paired with *train_scores* by iteration
                order, not by key.
        """
        out_dir = self._make_subdir(dir_name)

        epochs = list(train_scores.keys())
        train = list(train_scores.values())
        val = list(val_scores.values())

        train_losses = [score["loss"] for score in train]
        train_accuracies = [list(score["accuracy"]) for score in train]
        test_losses = [score["loss"] for score in val]
        test_accuracies = [list(score["accuracy"]) for score in val]
        sparsity = [score["sparsity"] for score in train]
        times = [score["time_taken"] for score in train]

        self._write_finetune_csv(
            os.path.join(out_dir, "results.csv"),
            epochs,
            train_losses,
            train_accuracies,
            test_losses,
            test_accuracies,
            sparsity,
            times,
        )

        fig = plot_curve(epochs, sparsity, xlabel="Epochs", ylabel="Sparsity")
        fig.savefig(os.path.join(out_dir, "sparseloss_curve.png"))
