import random
import numpy as np
from tqdm import tqdm
import time
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


def seed(seed=0):
    """Seed every random number generator this script relies on.

    Applies the same seed to Python's ``random``, NumPy's global generator and
    PyTorch (CPU and CUDA generators), then switches cuDNN to deterministic
    mode with benchmarking disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class LinearClassifier(nn.Module):
    """A single fully connected layer that maps features to class logits.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output logits (defaults to 2).
    """

    def __init__(self, input_dim, output_dim=2):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return the raw (un-normalised) logits for ``x``."""
        return self.linear(x)


def evaluate(x, y, x_test, y_test, model):
    """Compute classification accuracy of ``model`` on a train and a test split.

    The model is put into eval mode and is left in that mode on return.

    Args:
        x, y: Training inputs and integer class labels.
        x_test, y_test: Test inputs and integer class labels.
        model: Callable module producing per-class scores along dimension 1.

    Returns:
        ``(train_acc, test_acc)`` as Python floats (fraction of correct
        predictions on each split).
    """

    def _accuracy(inputs, labels):
        # Predicted class = index of the largest score along dim 1.
        scores = model(inputs)
        predicted = torch.max(scores, 1).indices
        return (labels == predicted).sum() / labels.shape[0]

    model.eval()
    with torch.no_grad():
        train_acc = _accuracy(x, y)
        test_acc = _accuracy(x_test, y_test)
    return train_acc.item(), test_acc.item()


def train(x, y, x_test, y_test, device, n_epoch=1000, eval_epoch=1, plot=False):
    """Fit a ``LinearClassifier`` on ``(x, y)`` with full-batch SGD.

    Each epoch is one gradient step on the whole training set using
    cross-entropy loss and a learning rate of 0.1. Accuracy is measured every
    ``eval_epoch`` epochs and always on the final epoch.

    Args:
        x, y: Training inputs and integer class labels.
        x_test, y_test: Held-out inputs and labels used for the reported accuracy.
        device: Device that the data and the model are moved to.
        n_epoch: Number of gradient steps. When it is <= 0 no training happens
            and the freshly initialised model is evaluated once.
        eval_epoch: Evaluation period, in epochs.
        plot: When true (and at least one epoch ran), show the loss curve and
            the train/test accuracy curves once training has finished.

    Returns:
        ``(model, test_acc)``: the trained model and the test accuracy from the
        most recent evaluation.
    """
    model = LinearClassifier(input_dim=x.shape[1])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    x, y, x_test, y_test, model = (obj.to(device) for obj in (x, y, x_test, y_test, model))

    train_loss_curve = []
    train_acc_curve = []
    test_acc_curve = []
    # Stays None only when the loop body never runs (n_epoch <= 0).
    test_acc = None
    last_epoch = n_epoch - 1
    for epoch in range(n_epoch):
        model.train()
        output = model(x)
        loss = criterion(output, y)
        train_loss_curve.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if epoch % eval_epoch == 0 or epoch == last_epoch:
            train_acc, test_acc = evaluate(x, y, x_test, y_test, model)
            train_acc_curve.append(train_acc)
            test_acc_curve.append(test_acc)

    if test_acc is None:
        # No epoch ran: report the untrained model's accuracy instead of
        # failing on an unbound local at the return statement.
        _, test_acc = evaluate(x, y, x_test, y_test, model)

    # Plot once after training rather than re-checking the condition every epoch.
    if plot and train_loss_curve:
        plt.subplot(1, 2, 1)
        plt.plot(train_loss_curve)
        plt.subplot(1, 2, 2)
        plt.plot(train_acc_curve)
        plt.plot(test_acc_curve)
        plt.show()

    return model, test_acc


# teacher : x2 as input, student: x1 as input
def train_kd(x2, x1, y1, x1_test, y1_test, teacher_model, weight, device, n_epoch=1000, eval_epoch=1, plot=False):
    """Train a student ``LinearClassifier`` on ``x1`` with knowledge distillation.

    The teacher consumes ``x2`` and the student consumes ``x1``. The student
    minimises ``weight[0] * CE(student(x1), y1) + weight[1] * KL`` with
    full-batch SGD (lr 0.1), where the KL term compares the student's
    log-softmax output with the teacher's softmax output.

    The teacher is only used for inference: its predictions are computed once
    without autograd, so no gradients are written to its parameters.

    Args:
        x2: Teacher-modality training inputs (same samples as ``x1``).
        x1, y1: Student-modality training inputs and integer class labels.
        x1_test, y1_test: Student-modality test inputs and labels.
        teacher_model: Already trained teacher; switched to eval mode here.
        weight: Indexable pair ``(ground-truth loss weight, distillation loss weight)``.
        device: Device that all data and both models are moved to.
        n_epoch: Number of gradient steps. When it is <= 0 no training happens
            and the freshly initialised student is evaluated once.
        eval_epoch: Evaluation period, in epochs (the last epoch is always evaluated).
        plot: When true (and at least one epoch ran), show the loss and
            accuracy curves once training has finished.

    Returns:
        Test accuracy of the student from the most recent evaluation.
    """
    teacher_model.eval()
    model = LinearClassifier(input_dim=x1.shape[1])
    criterion = torch.nn.CrossEntropyLoss()
    # NOTE(review): the default reduction ('mean') averages over every element
    # (batch x classes) rather than per sample ('batchmean'). It is kept as-is
    # so the relative scale of the two loss terms does not change.
    criterion_pl = torch.nn.KLDivLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    x2, x1, y1, x1_test, y1_test, teacher_model, model = (
        obj.to(device) for obj in (x2, x1, y1, x1_test, y1_test, teacher_model, model)
    )

    # The teacher is never updated and x2 is fixed, so its soft targets are
    # loop-invariant. Computing them once under no_grad avoids a forward pass
    # per epoch and keeps backward() from accumulating unused gradients in the
    # teacher's parameters. The student's loss and gradients are unchanged.
    with torch.no_grad():
        soft_targets = torch.softmax(teacher_model(x2), dim=1)

    train_loss_curve = []
    train_acc_curve = []
    test_acc_curve = []
    # Stays None only when the loop body never runs (n_epoch <= 0).
    test_acc = None
    last_epoch = n_epoch - 1
    for epoch in range(n_epoch):
        model.train()
        output = model(x1)

        loss_gt = criterion(output, y1)
        loss_pl = criterion_pl(torch.log_softmax(output, dim=1), soft_targets)
        loss = weight[0] * loss_gt + weight[1] * loss_pl
        train_loss_curve.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if epoch % eval_epoch == 0 or epoch == last_epoch:
            train_acc, test_acc = evaluate(x1, y1, x1_test, y1_test, model)
            train_acc_curve.append(train_acc)
            test_acc_curve.append(test_acc)

    if test_acc is None:
        # No epoch ran: report the untrained student's accuracy instead of
        # failing on an unbound local at the return statement.
        _, test_acc = evaluate(x1, y1, x1_test, y1_test, model)

    # Plot once after training rather than re-checking the condition every epoch.
    if plot and train_loss_curve:
        plt.subplot(1, 2, 1)
        plt.plot(train_loss_curve)
        plt.subplot(1, 2, 2)
        plt.plot(train_acc_curve)
        plt.plot(test_acc_curve)
        plt.show()

    return test_acc


def gen_mm_data(a, n, mode, x1_dim=-1, x2_dim=-1, y_dim=-1, y2_dim=-1, overlap_dim=-1):
    """Generate a synthetic two-modality binary classification dataset.

    A latent Gaussian matrix ``xs`` is drawn and labelled with
    ``y = (xs @ a[:k] > 0)``, where ``k`` is the width of ``xs``. Both
    modalities start as independent Gaussian noise; selected columns are then
    overwritten with columns of ``xs`` so that each modality carries part of
    the label-relevant signal. Draws come from NumPy's global generator in the
    order ``xs``, ``x1``, ``x2``.

    Modes (values fixed inside a mode override the matching argument):
        'a1': x1 is 50-d, x2 is 10-d, xs is 10-d. x1[:, :10] = xs and
            x2[:, :overlap_dim] = xs[:, :overlap_dim].
        'a2': x1 and x2 are 50-d, xs is 50-d. x1[:, :10] = xs[:, :10] and
            x2[:, :10 + y2_dim] = xs[:, :10 + y2_dim].
        'b1': x1 is 50-d, x2 is 25-d, xs is 20-d. x1[:, :10] = xs[:, :10] and
            x2 copies the ten xs columns starting at ``10 - overlap_dim``.
        'b2': x1 and x2 are 50-d, xs is ``y_dim``-d. x1[:, :y_dim] = xs and
            x2[:, :overlap_dim] = x1[:, :overlap_dim].
        'c': x1 is 50-d, x2 is ``x2_dim``-d, xs is 10-d. Both modalities copy
            xs into their first 10 columns.
        'remove_beta': x1 and x2 are 50-d, xs is ``(y_dim + y2_dim)``-d.
            x1[:, :y_dim] = xs[:, :y_dim] and x2[:, :y_dim + y2_dim] = xs.

    Args:
        a: Weight vector defining the labels; only a leading slice is used.
        n: Number of samples.
        mode: One of the mode names listed above.
        x1_dim, x2_dim, y_dim, y2_dim, overlap_dim: Dimension settings; which
            of them are read depends on ``mode``.

    Returns:
        ``(x1, x2, y)``: two float tensors built with ``torch.Tensor`` and a
        ``torch.LongTensor`` of 0/1 labels.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    if mode == 'a1':
        # increase 1-alpha
        x1_dim = 50
        x2_dim = 10
        y_dim = 10
        # overlap dim from 0 to 10, 1 - alpha = overlap_dim / y_dim

        xs = np.random.randn(n, y_dim)
        a = a[0:y_dim]
        y = (np.dot(xs, a) > 0).ravel()

        x1 = np.random.randn(n, x1_dim)
        x1[:, 0:y_dim] = xs[:, 0:y_dim]

        x2 = np.random.randn(n, x2_dim)
        x2[:, 0:overlap_dim] = xs[:, 0:overlap_dim]

    elif mode == 'a2':
        # label relevant info, fix overlap area (y_dim), increase y2_dim (β), no help for distillation
        x1_dim = 50
        x2_dim = 50
        y_dim = 10
        # y2_dim from 0 to 40, beta = y2_dim / (y_dim+y2_dim)

        xs = np.random.randn(n, x1_dim)
        a = a[0:x1_dim]
        y = (np.dot(xs, a) > 0).ravel()

        x1 = np.random.randn(n, x1_dim)
        x1[:, 0:y_dim] = xs[:, 0:y_dim]

        x2 = np.random.randn(n, x2_dim)
        x2[:, 0:y_dim + y2_dim] = xs[:, 0:y_dim + y2_dim]

    elif mode == 'b1':
        # modality overlap, fix x2, increase x1∩x2 area
        x1_dim = 50
        x2_dim = 25
        y_dim = 10
        y2_dim = 10
        # overlap dim from 0 to 10

        xs = np.random.randn(n, y_dim + y2_dim)
        a = a[0:y_dim + y2_dim]
        y = (np.dot(xs, a) > 0).ravel()

        x1 = np.random.randn(n, x1_dim)
        x1[:, 0:y_dim] = xs[:, 0:y_dim]

        x2 = np.random.randn(n, x2_dim)
        # Window of y2_dim latent columns that slides left as overlap_dim grows.
        lo = y_dim - overlap_dim
        hi = y2_dim - overlap_dim + y_dim
        x2[:, lo:hi] = xs[:, lo:hi]

    elif mode == 'b2':
        x1_dim = 50
        x2_dim = 50
        xs = np.random.randn(n, y_dim)
        a = a[0:y_dim]
        y = (np.dot(xs, a) > 0).ravel()

        x1 = np.random.randn(n, x1_dim)
        x1[:, 0:y_dim] = xs[:, 0:y_dim]

        x2 = np.random.randn(n, x2_dim)
        x2[:, 0:overlap_dim] = x1[:, 0:overlap_dim]

    elif mode == 'c':
        # label irrelevant info (noise)
        x1_dim = 50
        y_dim = 10
        # x2_dim from 10 to 50

        xs = np.random.randn(n, y_dim)
        a = a[0:y_dim]
        y = (np.dot(xs, a) > 0).ravel()

        x1 = np.random.randn(n, x1_dim)
        x1[:, 0:y_dim] = xs

        x2 = np.random.randn(n, x2_dim)
        x2[:, 0:y_dim] = xs

    elif mode == 'remove_beta':
        x1_dim = 50
        x2_dim = 50
        xs = np.random.randn(n, y_dim + y2_dim)
        a = a[0: y_dim + y2_dim]
        y = (np.dot(xs, a) > 0).ravel()

        x1 = np.random.randn(n, x1_dim)
        x1[:, 0:y_dim] = xs[:, 0:y_dim]

        x2 = np.random.randn(n, x2_dim)
        x2[:, 0:y_dim + y2_dim] = xs

    else:
        # Without this branch an unknown mode fell through to the return
        # statement and failed there with an unrelated UnboundLocalError.
        raise ValueError(f"unknown mode: {mode!r}")

    return torch.Tensor(x1), torch.Tensor(x2), torch.LongTensor(y)


def run_mm(seed_num, mode, y_dim, y2_dim, overlap_dim, modify=False):
    """Run one teacher / naive-student / distilled-student comparison.

    Generates 200 training and 1000 test samples with ``gen_mm_data``, trains
    a teacher on modality ``x2``, a student on modality ``x1`` without
    distillation, and a student on ``x1`` distilled from the teacher with
    loss weights ``[1.0, 1.0]``.

    Reads the module-level ``device`` variable, which must be defined before
    this function is called.

    Args:
        seed_num: Seed for the label weights and training data; the test data
            uses ``seed_num + 1``.
        mode: Data-generation mode forwarded to ``gen_mm_data``.
        y_dim, y2_dim, overlap_dim: Dimension settings forwarded to ``gen_mm_data``.
        modify: When true, zero columns ``y_dim .. y_dim + y2_dim - 1`` of the
            teacher's training inputs before training.

    Returns:
        ``(teacher_acc, naive_student_acc, kd_student_acc)`` test accuracies.
    """
    d = 500
    n_train = 200
    n_test = 1000
    # Seed before drawing the label weights. Sampling `a` ahead of seed()
    # made it depend on whatever RNG state preceded this call, so the first
    # run in a process was not reproducible.
    seed(seed_num)
    a = np.random.randn(d)
    x1, x2, y = gen_mm_data(a, n_train, mode, y_dim=y_dim, y2_dim=y2_dim, overlap_dim=overlap_dim)
    seed(seed_num + 1)
    x1_test, x2_test, y_test = gen_mm_data(a, n_test, mode, y_dim=y_dim, y2_dim=y2_dim, overlap_dim=overlap_dim)

    if modify:
        beta_idx = range(y_dim, y_dim + y2_dim)
        x2[:, beta_idx] = 0

    # teacher : x2 as input, student: x1 as input
    teacher_model, teacher_acc = train(x2, y, x2_test, y_test, device=device, n_epoch=1000)
    _, naive_student_acc = train(x1, y, x1_test, y_test, device=device, n_epoch=1000)
    kd_student_acc = train_kd(x2, x1, y, x1_test, y_test, teacher_model, [1.0, 1.0], device)

    return teacher_acc, naive_student_acc, kd_student_acc


def exp1(seed, n_runs):
    """Sweep ``overlap_dim`` in mode 'b1' and print accuracy statistics.

    For every overlap setting, ``run_mm`` is executed ``n_runs`` times with
    seeds ``seed, seed + 1, ...``. The mean teacher, naive-student and
    KD-student accuracies (in percent) are printed, followed by the mean and
    standard deviation of the per-run KD-minus-naive accuracy gap.

    Args:
        seed: Base seed; run ``i`` uses ``seed + i``.
        n_runs: Number of repetitions per setting.
    """
    for overlap_dim in (0, 2, 4, 6, 8, 10):
        gamma = overlap_dim / (20 - overlap_dim)

        # One row per run: (teacher, naive student, KD student) accuracy.
        results = np.zeros((n_runs, 3))
        for run_idx in range(n_runs):
            results[run_idx, :] = run_mm(seed + run_idx, 'b1', -1, -1, overlap_dim)

        log_mean = 100 * results.mean(axis=0)
        delta = np.round((results[:, 2] - results[:, 1]) * 100, 2)

        report = [
            f'gamma = {gamma:.2f}',
            f'Teacher acc {log_mean[0]:.2f}',
            f'Naive student acc {log_mean[1]:.2f}',
            f'KD student acc {log_mean[2]:.2f}',
            f'Delta: {np.mean(delta):.2f} ± {np.std(delta):.2f}',
            '-' * 60,
        ]
        print('\n'.join(report))


def exp2(seed, n_runs):
    """Sweep ``y2_dim`` in mode 'a2' and print accuracy statistics.

    For every ``y2_dim`` setting, ``run_mm`` is executed ``n_runs`` times with
    seeds ``seed, seed + 1, ...``. The mean teacher, naive-student and
    KD-student accuracies (in percent) are printed, followed by the mean and
    standard deviation of the per-run KD-minus-naive accuracy gap.

    Args:
        seed: Base seed; run ``i`` uses ``seed + i``.
        n_runs: Number of repetitions per setting.
    """
    for y2_dim in (0, 10, 20, 30, 40):
        alpha_2 = y2_dim / (y2_dim + 10)

        # One row per run: (teacher, naive student, KD student) accuracy.
        results = np.zeros((n_runs, 3))
        for run_idx in range(n_runs):
            results[run_idx, :] = run_mm(seed + run_idx, 'a2', -1, y2_dim, -1)

        log_mean = 100 * results.mean(axis=0)
        delta = np.round((results[:, 2] - results[:, 1]) * 100, 2)

        report = [
            f'alpha = {alpha_2}',
            f'Teacher acc {log_mean[0]:.2f}',
            f'Naive student acc {log_mean[1]:.2f}',
            f'KD student acc {log_mean[2]:.2f}',
            f'Delta: {np.mean(delta):.2f} ± {np.std(delta):.2f}',
            '-' * 60,
        ]
        print('\n'.join(report))


if __name__ == '__main__':
    # Naming note: the main paper calls the teacher modality $x_a$ and the
    # student modality $x_b$; in this script $x_2$ is the teacher modality
    # and $x_1$ is the student modality.
    # `device` is a module-level name that run_mm reads when training.
    device = torch.device('cpu')
    for experiment in (exp1, exp2):
        experiment(seed=0, n_runs=10)
