import math

import numpy as np
import matplotlib.pyplot as plt
import torch


# Values of MTE to sweep over: powers of two from 2**2 up to 2**7.
# Each value is passed to train_teacher_student, where it is the teacher's
# input dimension.
MTEs = [2 ** i for i in range(2, 8)]

# Which synthetic data distribution train_teacher_student constructs.
data_model = "weirder" # choices: decay, weird, easy, weirder
# "decay": raw dimension is round(MTE ** decay_alpha) and the i-th covariance
# eigenvalue is (i+1) ** (-decay_alpha).
decay_alpha = 1.1
# "weird": the number of unit eigenvalues is
# round((1 + MTE ** (-weird_p)) * MTE + 1).
weird_p = 0.5
# "weirder": ground-truth weight on the first K coordinates is
# MTE ** (weirder_alpha / 2).
weirder_alpha = 1.5
# "weirder": ground-truth weight on the remaining d - K coordinates is
# MTE ** ((1 - weirder_beta) / 2).
weirder_beta = 1.0
# "weirder": raw dimension d = round(MTE ** weirder_r).
weirder_r = 1.5
# "weirder": K = round(MTE ** weirder_p) coordinates get eigenvalue 1.0; the
# other d - K get eigenvalue 1/MTE.
weirder_p = 0.5
# "weirder": student input dimension MST = round(d ** weirder_MST_exp).
weirder_MST_exp = 1.25

# Number of fresh samples drawn for every SGD step.
batch_size = 1024
# SGD learning rate, used for both the teacher and the student optimizer.
lr = 0.1
# Number of SGD steps for the teacher and for the student, respectively.
teacher_train_iters = 1000
student_train_iters = 1000
# Roughly how many progress lines each training loop prints (iteration 0 is
# always printed in addition to the evenly spaced checkpoints).
num_checkpoint_prints = 100

# Seed handed to torch.manual_seed at the start of every train_teacher_student call.
seed = 0
# Set to False to run on CPU even when CUDA is available.
cuda = True

# First CUDA device when CUDA is both available and enabled above, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() and cuda else "cpu")


def _checkpoint_iters(total_iters, num_prints):
    """Return the set of 0-based iterations at which progress is printed.

    Iteration 0 is always included, followed by ``num_prints`` checkpoints
    spaced evenly so that the last one is the final iteration.
    """
    return {0} | {
        round((i + 1) * total_iters / num_prints) - 1 for i in range(num_prints)
    }


def train_teacher_student(MTE):
    """Train a linear teacher on true labels, then a linear student on teacher labels.

    Raw inputs are zero-mean Gaussian vectors in ``d`` dimensions with diagonal
    covariance ``cov_eigvals``; the true label is ``raw_x @ ground_truth``.
    ``d``, the student input size ``MST``, the eigenvalues and the ground truth
    are all derived from ``MTE`` according to the module-level ``data_model``.

    The teacher only sees the raw input through a fixed random ``d x MTE``
    projection and is fit with SGD on the MSE against the true labels. The
    student sees a separate fixed random ``d x MST`` projection and is fit with
    SGD on the MSE against the teacher's predictions; its MSE against the true
    labels is recorded as well but never used for training. Both models are
    trained on freshly sampled batches at every step.

    Args:
        MTE: Input dimension of the teacher model.

    Returns:
        A tuple ``(teacher_losses, student_train_losses, student_true_losses)``
        of per-iteration Python floats. Each list has one entry per training
        iteration; if a loss becomes nan, training stops early and the
        remainder of the affected list(s) is filled with nan.

    Raises:
        NotImplementedError: If ``data_model`` is not one of
            "decay", "weird", "easy", "weirder".
    """
    # Re-seed on every call so each MTE run starts from the same RNG state.
    torch.manual_seed(seed)

    # Define data distribution.
    if data_model == "decay":
        d = round(MTE ** decay_alpha)
        MST = 2 * d
        cov_eigvals = [(i+1) ** (-decay_alpha) for i in range(d)]
        ground_truth = torch.tensor([1/eigval for eigval in cov_eigvals], device=device)
    elif data_model == "weird":
        epsilon = MTE ** (-weird_p)
        num_ones = round((1 + epsilon) * MTE + 1)
        num_small = round(MTE ** (3/2))
        d = num_ones + num_small
        MST = 2 * d
        cov_eigvals = [1] * num_ones + [1/MTE] * num_small
        ground_truth = torch.tensor(
            [1] * num_ones + [math.sqrt(MTE)] * num_small, device=device
        )
    elif data_model == "easy":
        d = MTE
        MST = MTE
        cov_eigvals = [1.0] * d
        ground_truth = torch.tensor([1.0] + [0.0] * (d-1), device=device)
    elif data_model == "weirder":
        d = round(MTE ** weirder_r)
        K = round(MTE ** weirder_p)
        MST = round(d ** weirder_MST_exp)
        cov_eigvals = [1.0] * K + [1/MTE] * (d-K)
        ground_truth = torch.tensor(
            [MTE ** (weirder_alpha/2)] * K + [MTE ** ((1-weirder_beta)/2)] * (d-K),
            device=device
        )
    else:
        raise NotImplementedError(
            f"Unknown data_model {data_model!r}; "
            "expected one of: decay, weird, easy, weirder"
        )

    # Per-coordinate standard deviations, broadcast once to the batch shape so
    # that every batch is a single torch.normal call.
    data_std = torch.sqrt(torch.tensor(cov_eigvals, device=device))
    batch_mean = torch.zeros(batch_size, d, device=device)
    batch_std = data_std.view(1, d).expand(batch_size, d)

    def get_raw_batch():
        """Sample a fresh batch of raw inputs and their true labels."""
        raw_x = torch.normal(mean=batch_mean, std=batch_std)
        raw_y = raw_x @ ground_truth
        return raw_x, raw_y

    # Fixed random projections; they are sampled before either model is
    # created, so the order of RNG draws here matters for reproducibility.
    teacher_projection = torch.normal(
        mean=torch.zeros(d, MTE, device=device),
        std=torch.ones(d, MTE, device=device) / math.sqrt(MTE)
    )

    def get_teacher_batch():
        """Sample a batch as seen by the teacher: projected inputs, true labels."""
        raw_x, raw_y = get_raw_batch()
        projected_x = raw_x @ teacher_projection
        return projected_x, raw_y

    student_projection = torch.normal(
        mean=torch.zeros(d, MST, device=device),
        std=torch.ones(d, MST, device=device) / math.sqrt(MST)
    )

    def get_student_batch(teacher):
        """Sample a batch as seen by the student.

        Returns the student-projected inputs, the teacher's predictions on the
        same raw inputs (computed without gradients), and the true labels.
        """
        raw_x, raw_y = get_raw_batch()
        projected_x = raw_x @ student_projection
        with torch.no_grad():
            teacher_y = teacher(raw_x @ teacher_projection).squeeze(1)
        return projected_x, teacher_y, raw_y

    print("\n\n--------------------------------")
    print(f"Running MTE={MTE}, d={d}, MST={MST}")
    print("--------------------------------\n")

    # Train teacher with online data.
    teacher = torch.nn.Linear(in_features=MTE, out_features=1, bias=False, device=device)
    teacher_optimizer = torch.optim.SGD(params=teacher.parameters(), lr=lr)
    print("Training teacher:")
    print_iters = _checkpoint_iters(teacher_train_iters, num_checkpoint_prints)
    teacher_losses = []
    for t in range(teacher_train_iters):
        teacher_optimizer.zero_grad()
        teacher_batch_x, teacher_batch_y = get_teacher_batch()
        teacher_pred = teacher(teacher_batch_x).squeeze(1)
        loss = torch.mean((teacher_pred - teacher_batch_y) ** 2)
        # Read the scalar once; every .item() call copies it off the device.
        loss_value = loss.item()
        if t in print_iters:
            print(f"Iteration {t+1}/{teacher_train_iters}: Loss={loss_value:.5f}")

        loss.backward()
        teacher_optimizer.step()

        teacher_losses.append(loss_value)

        # Check for nans. Pad the history so it always has one entry per
        # scheduled iteration, then stop training.
        if math.isnan(loss_value):
            teacher_losses += [loss_value] * (teacher_train_iters - 1 - t)
            print("Encountered nan!")
            break

    # Train student with online data, labels generated by teacher.
    student = torch.nn.Linear(in_features=MST, out_features=1, bias=False, device=device)
    student_optimizer = torch.optim.SGD(params=student.parameters(), lr=lr)
    print("\nTraining student:")
    print_iters = _checkpoint_iters(student_train_iters, num_checkpoint_prints)
    student_train_losses = []
    student_true_losses = []
    for t in range(student_train_iters):
        student_optimizer.zero_grad()
        student_batch_x, student_batch_y, student_true_y = get_student_batch(teacher)
        student_pred = student(student_batch_x).squeeze(1)
        loss = torch.mean((student_pred - student_batch_y) ** 2)
        # The true-label loss is only monitored, never back-propagated.
        with torch.no_grad():
            true_loss = torch.mean((student_pred - student_true_y) ** 2)
        loss_value = loss.item()
        true_loss_value = true_loss.item()
        if t in print_iters:
            print(f"Iteration {t+1}/{student_train_iters}: Training Loss={loss_value:.5f}, True Loss={true_loss_value:.5f}")

        loss.backward()
        student_optimizer.step()

        student_train_losses.append(loss_value)
        student_true_losses.append(true_loss_value)

        # Check for nans (on the training loss only), padding both histories.
        if math.isnan(loss_value):
            student_train_losses += [loss_value] * (student_train_iters - 1 - t)
            student_true_losses += [true_loss_value] * (student_train_iters - 1 - t)
            print("Encountered nan!")
            break

    return teacher_losses, student_train_losses, student_true_losses


def main():
    """Run the teacher/student sweep over ``MTEs`` and save the loss plots.

    For every MTE this trains a teacher and a student, then writes:

    * ``wts.eps``: the student's true loss divided by the best teacher loss,
      one curve per MTE.
    * ``{i}_{MTE}_losses.eps``: teacher loss, student training loss and
      student true loss for that MTE, with the best teacher loss drawn as a
      dashed horizontal line.

    A one-line summary per MTE is also printed.
    """
    # Run teacher-student training for every teacher input size MTE.
    all_teacher_losses = {}
    all_student_train_losses = {}
    all_student_true_losses = {}
    for MTE in MTEs:
        teacher_losses, student_train_losses, student_true_losses = train_teacher_student(MTE)
        all_teacher_losses[MTE] = np.array(teacher_losses)
        all_student_train_losses[MTE] = np.array(student_train_losses)
        all_student_true_losses[MTE] = np.array(student_true_losses)

    # Plot loss ratios.
    print("\n")
    for MTE in MTEs:
        # Histories are padded with nan when training diverges, so use
        # nan-ignoring minima; np.min would turn every statistic into nan.
        LTE = np.nanmin(all_teacher_losses[MTE])
        LSTs = all_student_true_losses[MTE]
        ys = LSTs / LTE
        plt.plot(np.arange(len(ys)), ys, label=f"MTE = {MTE}")
        print(f"MTE={MTE} LTE: {LTE}, LST(min): {np.nanmin(LSTs)}, Ratio: {np.nanmin(ys)}")
    plt.legend()
    plt.xlabel("Training Steps")
    plt.ylabel("Loss Ratio")
    plt.ylim([0, 1.5])
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("wts.eps")
    plt.close()

    # Plot individual losses.
    for i, MTE in enumerate(MTEs):
        LTE = np.nanmin(all_teacher_losses[MTE])
        curves = [
            ("Teacher Loss", all_teacher_losses[MTE]),
            ("Student Train Loss", all_student_train_losses[MTE]),
            ("Student True Loss", all_student_true_losses[MTE]),
        ]
        # Give each curve its own x-axis: teacher and student histories differ
        # in length whenever teacher_train_iters != student_train_iters.
        for label, losses in curves:
            plt.plot(np.arange(len(losses)), losses, label=label)
        last_step = max(len(losses) for _, losses in curves) - 1
        plt.plot([0, last_step], [LTE, LTE], linestyle="--", color="black")
        # Matplotlib rejects nan/inf axis limits, so keep the automatic range
        # if no finite teacher loss exists.
        if np.isfinite(LTE):
            plt.ylim([0, LTE * 2])
        plt.legend()
        plt.xlabel("Training Steps")
        plt.ylabel("MSE Loss")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f"{i}_{MTE}_losses.eps")
        plt.close()

# Run the full sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
