import copy
import numpy as np
from multiprocessing import Pool
from functools import partial

from munch import Munch
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl

from train import LightningModel
from datasets import DatasetBase
import kd


# Show floats in plain fixed-point form (no scientific notation) for both
# torch tensors and NumPy arrays printed by the experiments below.
torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)

# Generate data
def generate_data(d=10, N=20, r=0.5, verbose=True):
    """Sample a symmetric two-component Gaussian mixture and posterior-sampled labels.

    The class means are ``r * (1, ..., 1)`` and its negation. Each point is drawn from a
    uniformly chosen component with identity covariance. Its label is then re-sampled
    from the exact class posterior ``P(y = 1 | x) = softmax(mu @ x)[1]`` rather than
    taken from the component that generated it.

    Args:
        d: input dimension.
        N: number of samples.
        r: per-coordinate magnitude of the class means.
        verbose: print the sampled labels (default True, kept for backward compatibility).

    Returns:
        ``(X, Y)``: a float tensor of shape ``(N, d)`` and int64 labels of shape ``(N,)``.
    """
    mu = r * torch.stack((torch.ones(d), -torch.ones(d)))
    Z = torch.randint(2, (N, ))
    X = torch.randn(N, d) + mu[Z]

    # Sanity check: both means have the same norm, so the Gaussian posterior
    # exp(-||x - mu_k||^2 / 2) / sum_j exp(-||x - mu_j||^2 / 2) equals softmax_k(mu_k . x).
    # Normalise in log space. Exponentiating first underflows to 0 for large squared
    # distances (large d or r), which turned the check into 0/0 = NaN and a spurious failure.
    log_p_unnormalized = -(X - mu[:, None]).pow(2).sum(dim=-1) / 2
    p_normalized = torch.softmax(log_p_unnormalized, dim=0)
    dot = mu @ X.t()
    pstar = torch.softmax(dot, dim=0)
    assert torch.allclose(p_normalized, pstar)

    # Label 1 with the posterior probability of the second component.
    Y = (torch.rand(N) < pstar[1]).long()
    if verbose:
        print(Y)
    return X, Y


class LinearLightningModel(LightningModel):
    """Single ``nn.Linear`` classifier trained on the dataset's hard labels.

    The constructor calls ``pl.LightningModule.__init__`` directly, bypassing
    ``LightningModel.__init__``. It then attaches the configs, the dataset (looked
    up by name in ``DatasetBase.registry``) and the linear layer.
    """

    def __init__(self, model_cfg, dataset_cfg, train_cfg):
        pl.LightningModule.__init__(self)
        self.dataset_cfg = dataset_cfg
        dataset_factory = DatasetBase.registry[dataset_cfg.name]
        self.dataset = dataset_factory(dataset_cfg)
        self.train_cfg = train_cfg
        self.model_cfg = model_cfg
        in_features, out_features = dataset_cfg.input_size, dataset_cfg.output_size
        self.model = nn.Linear(in_features, out_features, bias=model_cfg.bias)

    def training_step(self, batch, batch_idx):
        """Compute the regularised hard-label loss and bookkeeping for one batch.

        The objective is ``dataset.loss(out, y) + l2 * ||W||_F^2``. Only the weight
        matrix is penalised, not the bias.
        """
        inputs, targets = batch
        logits = self.forward(inputs)
        data_loss = self.dataset.loss(logits, targets)
        weight_penalty = self.train_cfg.l2 * torch.norm(self.model.weight) ** 2
        loss = data_loss + weight_penalty
        metrics = self.dataset.metrics(logits, targets)
        return {
            'loss': loss,
            'size': inputs.shape[0],
            'out': logits,
            'target': targets,
            'progress_bar': metrics,
            'log': metrics,
        }

class BayesLightningModel(LinearLightningModel):
    """Evaluation-only model whose outputs come from the dataset means, not from training.

    The outputs are ``x @ mu.T`` with ``mu = self.dataset.mu``. When both class means
    share a norm (as in ``generate_data``), these are the Bayes-posterior logits. This
    is assumed to hold for the dataset used here — confirm. The inherited ``nn.Linear``
    is never used, and no optimizer is configured.
    """

    def validation_step(self, batch, batch_idx, prefix='val'):
        """Score the (optionally "unsmoothed") teacher on one batch.

        If ``model_cfg.kd.unsmoothing = a`` is non-zero, the teacher probabilities are
        pushed towards a hard one-hot of its own prediction:
        ``(1 - a) * softmax(dot) + a * onehot(argmax(dot))``. Their log is then used as
        the output.

        ``prefix`` is not used here. Presumably the base class routes test batches
        through this method with ``prefix='test'`` — confirm.
        """
        batch_x, batch_y = batch
        dot = batch_x @ self.dataset.mu.t()
        out = dot
        unsmoothing = getattr(self.model_cfg.kd, 'unsmoothing', 0.0)
        if unsmoothing != 0.0:
            pstar = torch.softmax(dot, dim=-1)
            # One-hot of the teacher's own prediction. argmax is valid for any number of
            # classes. The previous ``dot >= 0`` test was only a one-hot for two
            # antisymmetric logits, and gave [1, 1] on ties.
            teacher_round = F.one_hot(dot.argmax(dim=-1), num_classes=dot.shape[-1]).to(dot.dtype)
            new_teacher_prob = (1 - unsmoothing) * pstar + unsmoothing * teacher_round
            out = torch.log(new_teacher_prob)
        loss = self.dataset.loss(out, batch_y)
        metrics = self.dataset.metrics(out, batch_y)
        return {'size': batch_x.shape[0], 'loss': loss, 'out': out, 'target': batch_y, **metrics}

    def configure_optimizers(self):
        # Nothing to optimise: predictions are computed from the dataset means.
        return None

class DistillLightningModel(LinearLightningModel):
    """Linear student distilled from a fixed teacher built from the dataset means.

    The teacher has no parameters: its logits are ``x @ mu.T`` with
    ``mu = self.dataset.mu``. The distillation loss class is looked up by name
    (``train_cfg.kd['class']``) in the ``kd`` module. It is constructed from the
    remaining ``train_cfg.kd`` entries.
    """

    def __init__(self, model_cfg, dataset_cfg, train_cfg):
        super().__init__(model_cfg, dataset_cfg, train_cfg)
        # The loss classes live in the ``kd`` module imported at the top of the file.
        # The lookup used to reference an undefined ``kd_utils`` and raised NameError.
        kdloss_cls = getattr(kd, self.train_cfg.kd['class'])
        kdloss_params = {k: v for k, v in self.train_cfg.kd.items() if k != 'class'}
        self.kd_loss = kdloss_cls(**kdloss_params)

    def _teacher_out(self, batch_x):
        """Return the teacher output passed to ``kd_loss`` for a batch of inputs.

        By default this is ``log_softmax(x @ mu.T)``. ``model_cfg.kd.unsmoothing = a``
        replaces the probabilities with ``(1 - a) * p + a * onehot(argmax)``.
        ``model_cfg.kd.power`` raises ``p`` to that power. If both are set, ``power``
        takes precedence, as in the original implementation.
        """
        dot = batch_x @ self.dataset.mu.t()
        teacher_out = torch.log_softmax(dot, dim=-1)
        unsmoothing = getattr(self.model_cfg.kd, 'unsmoothing', 0.0)
        if unsmoothing != 0.0:
            pstar = torch.softmax(dot, dim=-1)
            # One-hot of the teacher's prediction. This is valid for any number of
            # classes, unlike the former ``dot >= 0`` test (two antisymmetric logits
            # only; [1, 1] on ties).
            teacher_round = F.one_hot(dot.argmax(dim=-1), num_classes=dot.shape[-1]).to(dot.dtype)
            new_teacher_prob = (1 - unsmoothing) * pstar + unsmoothing * teacher_round
            teacher_out = torch.log(new_teacher_prob)
        power = getattr(self.model_cfg.kd, 'power', 1.0)
        if power != 1.0:
            pstar = torch.softmax(dot, dim=-1)
            # p ** power is not renormalised. NOTE(review): this is fine only if each
            # kd_loss class normalises its teacher input — confirm.
            new_teacher_prob = pstar ** power
            teacher_out = torch.log(new_teacher_prob)
        return teacher_out

    def training_step(self, batch, batch_idx):
        """Compute the distillation loss plus the L2 weight penalty for one batch.

        ``kd_loss`` receives the student output, the teacher output, the hard labels
        and the hard-label loss, in that order.
        """
        batch_x, batch_y = batch
        with torch.no_grad():  # the teacher is fixed; keep it out of the autograd graph
            teacher_out = self._teacher_out(batch_x)
        out = self.model(batch_x)
        loss_og = self.dataset.loss(out, batch_y)
        loss = self.kd_loss(out, teacher_out, batch_y, loss_og) + self.train_cfg.l2 * torch.norm(self.model.weight)**2
        metrics = self.dataset.metrics(out, batch_y)
        return {'loss': loss, 'size': batch_x.shape[0], 'out': out, 'target': batch_y,
                'progress_bar': metrics, 'log': metrics}

    # def validation_step(self, batch, batch_idx, prefix='val'):
    #     batch_x, batch_y = batch
    #     dot = batch_x @ self.dataset.mu.t()
    #     pstar = torch.softmax(dot, dim=-1)
    #     teacher_out = pstar
    #     out = teacher_out
    #     loss = self.dataset.loss(out, batch_y)
    #     metrics = self.dataset.metrics(out, batch_y)
    #     return {'size': batch_x.shape[0], 'loss': loss, 'out': out, 'target': batch_y, **metrics}

def linear_train(model_cfg, dataset_cfg, train_cfg):
    """Fit a plain linear classifier on the hard labels.

    Validation runs every 5 epochs and early stopping is disabled.

    Returns:
        ``(val_loss, val_accuracy)`` read from ``model._val_results`` after fitting.
        Presumably these come from the last validation run — confirm in ``LightningModel``.
    """
    trainer_kwargs = dict(
        max_epochs=train_cfg.epochs,
        progress_bar_refresh_rate=0,
        check_val_every_n_epoch=5,
        early_stop_callback=False,
    )
    model = LinearLightningModel(model_cfg, dataset_cfg, train_cfg)
    pl.Trainer(**trainer_kwargs).fit(model)
    results = model._val_results
    return results['val_loss'], results['val_accuracy']

def student_train(model_cfg, dataset_cfg, train_cfg, bestinit=False):
    """Distill a linear student from the fixed teacher and return its validation metrics.

    Args:
        bestinit: if True, start the student exactly at the teacher. The weight is set
            to the dataset means and, when the layer has one, the bias is set to 0. The
            initial logits are therefore ``x @ mu.T``.

    Returns:
        ``(val_loss, val_accuracy)`` read from ``student._val_results`` after fitting.
    """
    student = DistillLightningModel(model_cfg, dataset_cfg, train_cfg)
    if bestinit:
        # Presumably ``dataset.mu`` is only populated by ``prepare_data`` — confirm.
        student.prepare_data()
        student.model.weight = nn.Parameter(student.dataset.mu.detach().clone())
        # The teacher has no bias term. Leaving nn.Linear's random bias in place meant
        # the "best" initialisation was not actually the teacher.
        if student.model.bias is not None:
            with torch.no_grad():
                student.model.bias.zero_()
    trainer = pl.Trainer(max_epochs=train_cfg.epochs,
        progress_bar_refresh_rate=0,
        check_val_every_n_epoch=5,
        early_stop_callback=False,
        )
    trainer.fit(student)
    return student._val_results['val_loss'], student._val_results['val_accuracy']

N = 20  # training-set size; also used as the batch size in train_cfg below
ntrials = 5  # independent repetitions of every experiment
dataset_cfg = Munch({'name': 'gmm',
                     'input_size': 10,
                     'output_size': 2,
                     'mean_distance_scale': 0.5,
                     'train_length': N,
                     'val_length': 1000,
                     'num_workers': 0,  # Much faster to set it to 0
                     'seed': 12345
                    })
# dataset = DatasetBase.registry[dataset_cfg.name](dataset_cfg)
# dataset.prepare_data()
# dataset.prepare_dataloader(batch_size=N)

# NOTE(review): the student is built with a bias, yet the teacher (x @ mu.T in
# BayesLightningModel / DistillLightningModel) has none. An earlier note here said
# "Very important that bias=False", which contradicts the value below. Confirm which
# setting the experiments intend.
model_cfg = Munch({'name': 'linear', 'bias': True,
                   # kd.unsmoothing (and optionally kd.power) reshape the teacher's
                   # probabilities; 0.0 leaves them unchanged.
                   'kd': Munch({'unsmoothing': 0.0})
                  })

train_cfg = Munch({
    'batch_size': N,
    'epochs': 40,
    'l2': 1e-4,  # coefficient of the squared-Frobenius penalty on the student's weight
    # Distillation loss: 'class' names the loss class, the other keys are its kwargs.
    'kd': Munch({'class': 'KDLoss', 'temperature': 1.0, 'alpha': 1.0}),
    'optimizer': Munch({'class': 'SGD', 'lr': 3e-1}),
    # 'optimizer': Munch({'class': 'Adam', 'lr': 1e-2, 'weight_decay': 1e-5}),
    # 'optimizer': Munch({'class': 'LBFGS', 'lr': 1e-2}),
    'verbose': False
})

# Baseline: a linear classifier trained on the hard labels alone.
linear_results = np.array([linear_train(model_cfg, dataset_cfg, train_cfg)
                           for _ in range(ntrials)])
print(linear_results.mean(axis=0), linear_results.std(axis=0))

# Reference point: the teacher itself, evaluated on the validation split.
bayes = BayesLightningModel(model_cfg, dataset_cfg, train_cfg)
bayes.prepare_data()
trainer = pl.Trainer()
trainer.test(bayes, bayes.dataset.val_loader)
bayes_accs = [(bayes._test_results['test_loss'], bayes._test_results['test_accuracy'])]
# Report it: this was previously computed but never shown, and the unsmoothing
# evaluation loop further down reassigns ``bayes_accs``.
print(bayes_accs)

# Distilled students: default initialisation first, then initialised at the teacher.
student_accs = np.array([student_train(model_cfg, dataset_cfg, train_cfg)
                         for _ in range(ntrials)])
print(student_accs.mean(axis=0), student_accs.std(axis=0))

student_bestinit_accs = np.array([student_train(model_cfg, dataset_cfg, train_cfg, bestinit=True)
                                  for _ in range(ntrials)])
print(student_bestinit_accs.mean(axis=0), student_bestinit_accs.std(axis=0))

# student.model.weight = nn.Parameter(student.dataset.mu)
# trainer.test(student, student.dataset.val_loader)

# More confident teacher: blend its probabilities with a hard rounding of its prediction,
# (1 - a) * p + a * round(p), moving it away from uniform as the weight a grows.
unsmoothing_results = []
for unsmoothing in (0.0, 0.5, 0.75, 0.9):
    new_model_cfg = copy.deepcopy(model_cfg)
    new_model_cfg.kd.unsmoothing = unsmoothing
    unsmoothing_results.append(
        np.array([student_train(new_model_cfg, dataset_cfg, train_cfg)
                  for _ in range(ntrials)])
    )
unsmoothing_results = np.array(unsmoothing_results)
# One row per level: [mean loss, mean accuracy, std loss, std accuracy].
print(np.concatenate([unsmoothing_results.mean(axis=1),
                      unsmoothing_results.std(axis=1)], axis=1))

# Score the unsmoothed teachers themselves on the validation split.
for unsmoothing in (0.0, 0.5, 0.75, 0.9):
    new_model_cfg = copy.deepcopy(model_cfg)
    new_model_cfg.kd.unsmoothing = unsmoothing
    bayes = BayesLightningModel(new_model_cfg, dataset_cfg, train_cfg)
    bayes.prepare_data()
    trainer = pl.Trainer()
    trainer.test(bayes, bayes.dataset.val_loader)
    bayes_accs = [(bayes._test_results['test_loss'],
                   bayes._test_results['test_accuracy'])]
    print(f'#######Unsmoothing: {unsmoothing}')
    print(bayes_accs)

# Unsmoothing sweep again, distilling with the 'KDMSELoss' objective.
mse_results = []
for unsmoothing in (0.0, 0.5, 0.75, 0.9):
    new_model_cfg = copy.deepcopy(model_cfg)
    new_model_cfg.kd.unsmoothing = unsmoothing
    new_train_cfg = copy.deepcopy(train_cfg)
    new_train_cfg.kd = Munch({'class': 'KDMSELoss', 'alpha': 1.0})
    result = np.array([student_train(new_model_cfg, dataset_cfg, new_train_cfg)
                       for _ in range(ntrials)])
    mse_results.append(result)
mse_results = np.array(mse_results)
print(np.concatenate([mse_results.mean(axis=1), mse_results.std(axis=1)], axis=1))

# Unsmoothing sweep with the 'KDMSEOrthoLoss' distillation objective.
mseortho_results = []
for unsmoothing in (0.0, 0.5, 0.75, 0.9):
    new_model_cfg = copy.deepcopy(model_cfg)
    new_model_cfg.kd.unsmoothing = unsmoothing
    new_train_cfg = copy.deepcopy(train_cfg)
    new_train_cfg.kd = Munch({'class': 'KDMSEOrthoLoss', 'alpha': 1.0})
    mseortho_results.append(
        np.array([student_train(new_model_cfg, dataset_cfg, new_train_cfg)
                  for _ in range(ntrials)])
    )
mseortho_results = np.array(mseortho_results)
print(np.concatenate([mseortho_results.mean(axis=1),
                      mseortho_results.std(axis=1)], axis=1))

# Unsmoothing sweep with the 'KDMSEVarRedOrthoLoss' distillation objective.
mseorthovarred_results = []
for unsmoothing in (0.0, 0.5, 0.75, 0.9):
    new_model_cfg = copy.deepcopy(model_cfg)
    new_model_cfg.kd.unsmoothing = unsmoothing
    new_train_cfg = copy.deepcopy(train_cfg)
    new_train_cfg.kd = Munch({'class': 'KDMSEVarRedOrthoLoss', 'alpha': 1.0})
    result = np.array([student_train(new_model_cfg, dataset_cfg, new_train_cfg)
                       for _ in range(ntrials)])
    mseorthovarred_results.append(result)
mseorthovarred_results = np.array(mseorthovarred_results)
print(np.concatenate([mseorthovarred_results.mean(axis=1),
                      mseorthovarred_results.std(axis=1)], axis=1))

# More confident teacher via p ** power: power > 1 sharpens it, like a temperature below 1.
# power == 1.0 leaves the teacher unchanged and serves as the baseline row.
for power in (1.0, 2.0, 4.0, 8.0):
    new_model_cfg = copy.deepcopy(model_cfg)
    new_model_cfg.kd.power = power
    student_results = np.array([student_train(new_model_cfg, dataset_cfg, train_cfg)
                                for _ in range(ntrials)])
    print(f'#######Power: {power}')
    print(student_results.mean(axis=0), student_results.std(axis=0))

# Unsmoothing sweep with the 'KDMSEVarRedOrthoLoss1' distillation objective.
mseorthovarred1_results = []
for unsmoothing in (0.0, 0.5, 0.75, 0.9):
    new_model_cfg = copy.deepcopy(model_cfg)
    new_model_cfg.kd.unsmoothing = unsmoothing
    new_train_cfg = copy.deepcopy(train_cfg)
    new_train_cfg.kd = Munch({'class': 'KDMSEVarRedOrthoLoss1', 'alpha': 1.0})
    result = np.array([student_train(new_model_cfg, dataset_cfg, new_train_cfg)
                       for _ in range(ntrials)])
    mseorthovarred1_results.append(result)
mseorthovarred1_results = np.array(mseorthovarred1_results)
print(np.concatenate([mseorthovarred1_results.mean(axis=1),
                      mseorthovarred1_results.std(axis=1)], axis=1))

# Unsmoothing sweep with the 'KDMSEVarRedOrthoLoss4' distillation objective.
# The last iteration's ``new_train_cfg`` stays at module level and is reused below.
mseorthovarred4_results = []
for unsmoothing in (0.0, 0.5, 0.75, 0.9):
    new_model_cfg = copy.deepcopy(model_cfg)
    new_model_cfg.kd.unsmoothing = unsmoothing
    new_train_cfg = copy.deepcopy(train_cfg)
    new_train_cfg.kd = Munch({'class': 'KDMSEVarRedOrthoLoss4', 'alpha': 1.0})
    result = np.array([student_train(new_model_cfg, dataset_cfg, new_train_cfg)
                       for _ in range(ntrials)])
    mseorthovarred4_results.append(result)
mseorthovarred4_results = np.array(mseorthovarred4_results)
print(np.concatenate([mseorthovarred4_results.mean(axis=1),
                      mseorthovarred4_results.std(axis=1)], axis=1))

# Scratch settings on the config copy left over from the last sweep above. Nothing after
# this point in the script reads them; presumably they are kept for interactive use.
new_train_cfg.verbose = True
new_train_cfg.epochs = 40
# Alternatives that were tried: 'KDMSEVarRedOrthoLoss2', 'KDMSEVarRedOrthoLoss3'.
# Only the last of three consecutive assignments ever took effect, so the dead ones
# have been removed.
new_train_cfg.kd = Munch({'class': 'KDMSEVarRedOrthoLoss4', 'alpha': 1.0})
