import math
from itertools import product

import numpy as np
import matplotlib.pyplot as plt
import torch


# Student and teacher embedding widths to sweep (powers of two).
MSTs = [2 ** i for i in range(9, 14)]
MTEs = [2 ** i for i in range(4, 9)]

# Exponents defining the synthetic task as functions of m: input dimension d = m**r,
# number of high-variance directions K = m**p, and target scales m**(alpha/2) and
# m**((1-beta)/2) on the head and tail of the spectrum, respectively.
alpha = 1.5
beta = 1.0
r = 1.5
p = 0.5

batch_size = 1024
lr = 0.1
teacher_train_iters = 1000
student_train_iters = 1000
num_checkpoint_prints = 100

seed = 0
cuda = True

device = torch.device("cuda:0" if torch.cuda.is_available() and cuda else "cpu")


def train_teacher_student(m, MTE, MST):
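    """Train a linear teacher of width MTE on synthetic data, then distill a linear
    student of width MST from the teacher's predictions.

    Returns the teacher's loss trajectory, the student's training losses (against
    teacher labels), and the student's true losses (against ground-truth labels).
    """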

    torch.manual_seed(seed)

    # Define data distribution.
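    # Inputs live in d = m**r dimensions: the first K = m**p coordinates have unit
    # variance, the remaining d - K have variance 1/MTE. The ground-truth weights are
    # m**(alpha/2) on the first K coordinates and m**((1-beta)/2) on the rest.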
    d = round(m ** r)
    K = round(m ** p)
    cov_eigvals = [1.0] * K + [1/MTE] * (d-K)
    ground_truth = torch.tensor(
        [m ** (alpha/2)] * K + [m ** ((1-beta)/2)] * (d-K),
        device=device
    )

    data_std = torch.sqrt(torch.tensor(cov_eigvals, device=device))
    batch_mean = torch.zeros(batch_size, d, device=device)
    batch_std = data_std.view(1, d).expand(batch_size, d)
    def get_raw_batch():
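        # Sample a fresh Gaussian batch with the anisotropic covariance above and
        # label it with the ground-truth linear map.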
        raw_x = torch.normal(mean=batch_mean, std=batch_std)
        raw_y = raw_x @ ground_truth
        return raw_x, raw_y

    # Fixed random projection from the ambient d-dimensional space to the teacher's
    # MTE-dimensional input space, with entries of scale 1/sqrt(MTE).
    teacher_projection = torch.normal(
        mean=torch.zeros(d, MTE, device=device),
        std=torch.ones(d, MTE, device=device) / math.sqrt(MTE)
    )
    def get_teacher_batch():
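        # The teacher sees the data through its fixed random projection; labels stay exact.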
        raw_x, raw_y = get_raw_batch()
        projected_x = raw_x @ teacher_projection
        return projected_x, raw_y

    # Fixed random projection to the student's MST-dimensional input space.
    student_projection = torch.normal(
        mean=torch.zeros(d, MST, device=device),
        std=torch.ones(d, MST, device=device) / math.sqrt(MST)
    )
    def get_student_batch(teacher):
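        # The student sees its own projection; training labels come from the (frozen)
        # teacher, while the true labels are kept for evaluation.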
        raw_x, raw_y = get_raw_batch()
        projected_x = raw_x @ student_projection
        with torch.no_grad():
            teacher_y = teacher(raw_x @ teacher_projection).squeeze(1)
        return projected_x, teacher_y, raw_y

    print("\n\n--------------------------------")
    print(f"Running MTE={MTE}, d={d}, MST={MST}")
    print("--------------------------------\n")

    # Train teacher with online data.
    teacher = torch.nn.Linear(in_features=MTE, out_features=1, bias=False, device=device)
    teacher_optimizer = torch.optim.SGD(params=teacher.parameters(), lr=lr)
    print("Training teacher:")
    print_iters = [0] + [
        round((i + 1) * teacher_train_iters / num_checkpoint_prints) - 1
        for i in range(num_checkpoint_prints)
    ]
    teacher_losses = []
    for t in range(teacher_train_iters):
        teacher_optimizer.zero_grad()
        teacher_batch_x, teacher_batch_y = get_teacher_batch()
        teacher_pred = teacher(teacher_batch_x).squeeze(1)
        loss = torch.mean((teacher_pred - teacher_batch_y) ** 2)
        if t in print_iters:
            print(f"Iteration {t+1}/{teacher_train_iters}: Loss={loss.item():.5f}")

        loss.backward()
        teacher_optimizer.step()

        teacher_losses.append(loss.item())

        # If the loss diverged, pad the remaining iterations with the NaN loss and stop early.
        if bool(torch.isnan(loss)):
            teacher_losses += [loss.item()] * (teacher_train_iters - 1 - t)
            print("Encountered nan!")
            break

    # Train student with online data, labels generated by teacher.
    student = torch.nn.Linear(in_features=MST, out_features=1, bias=False, device=device)
    student_optimizer = torch.optim.SGD(params=student.parameters(), lr=lr)
    print("\nTraining student:")
    print_iters = [0] + [
        round((i + 1) * student_train_iters / num_checkpoint_prints) - 1
        for i in range(num_checkpoint_prints)
    ]
    student_train_losses = []
    student_true_losses = []
    for t in range(student_train_iters):
        student_optimizer.zero_grad()
        student_batch_x, student_batch_y, student_true_y = get_student_batch(teacher)
        student_pred = student(student_batch_x).squeeze(1)
        loss = torch.mean((student_pred - student_batch_y) ** 2)
        with torch.no_grad():
            true_loss = torch.mean((student_pred - student_true_y) ** 2)
        if t in print_iters:
            print(f"Iteration {t+1}/{student_train_iters}: Training Loss={loss.item():.5f}, True Loss={true_loss.item():.5f}")

        loss.backward()
        student_optimizer.step()

        student_train_losses.append(loss.item())
        student_true_losses.append(true_loss.item())

        # If the loss diverged, pad the remaining iterations with the NaN losses and stop early.
        if bool(torch.isnan(loss)):
            student_train_losses += [loss.item()] * (student_train_iters - 1 - t)
            student_true_losses += [true_loss.item()] * (student_train_iters - 1 - t)
            print("Encountered nan!")
            break

    return teacher_losses, student_train_losses, student_true_losses


def evaluate_m(m):
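    """Sweep all (teacher width, student width) pairs for a fixed m, record the ratio
    of the student's best true loss to the teacher's best loss, and plot it as a heatmap.
    """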

    # Run teacher-student training over all (teacher width, student width) pairs.
    all_teacher_losses = {}
    all_student_train_losses = {}
    all_student_true_losses = {}
    ratios = np.zeros((len(MTEs), len(MSTs)))
    for (i, MTE), (j, MST) in product(enumerate(MTEs), enumerate(MSTs)):
        teacher_losses, student_train_losses, student_true_losses = train_teacher_student(m, MTE, MST)
        all_teacher_losses[(MTE, MST)] = np.array(teacher_losses)
        all_student_train_losses[(MTE, MST)] = np.array(student_train_losses)
        all_student_true_losses[(MTE, MST)] = np.array(student_true_losses)

        LTE = np.min(teacher_losses)
        LST = np.min(student_true_losses)
        ratios[i, j] = LST / LTE

    # Plot loss ratios: rows of `ratios` are teacher widths (MTEs), columns are student widths (MSTs).
    teacher_labels = ["$2^{" + str(round(math.log2(MTE))) + "}$" for MTE in MTEs]
    student_labels = ["$2^{" + str(round(math.log2(MST))) + "}$" for MST in MSTs]
    plt.matshow(ratios, cmap="inferno")
    plt.xlabel("Student size")
    plt.ylabel("Teacher size")
    plt.xticks(range(len(MSTs)), student_labels)
    plt.yticks(range(len(MTEs)), teacher_labels)
    plt.colorbar()

    plt.tight_layout()
    plt.savefig(f"compute_ablation_{m}.eps")
    plt.close()


def main():
    evaluate_m(m=32)
    evaluate_m(m=64)


if __name__ == "__main__":
    main()
