import sys
import copy
import torch
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 300
import numpy as np

sys.path.append("../")
from neural_networks.FusionModel.fusion_methods.naive_fusion import NaiveFusion
from neural_networks import data_loader
from neural_networks.MLP import Deep_MLP, MLP
from neural_networks.FusionModel.fusion_model import FusionModel
from neural_networks.FusionModel.fusion_methods.partial_fusion import PartialFusion


train_loader, test_loader = data_loader.load_mnist()

# Gather the first 1000 input samples; they are handed to FusionModel as
# `data` when feature_base == 'activation'.
# NOTE(review): despite the name, these samples are drawn from train_loader,
# not test_loader -- confirm this is intended.
num_samples = 1000
test_data = []
num_collected = 0
for batch in train_loader:
    inputs = batch[0]
    test_data.append(inputs)
    num_collected += inputs.shape[0]
    # Stop once enough samples are buffered instead of walking (and
    # concatenating) the whole loader just to slice off the first 1000.
    if num_collected >= num_samples:
        break
test_data = torch.cat(test_data, dim=0)[:num_samples]

# Forwarded as `criterion` to every test_model call below. None for now;
# torch.nn.CrossEntropyLoss() is the alternative that was used before.
criterion = None
# True: train both models from scratch and save them under saved/.
# False: load existing checkpoints from saved_compression/.
save = False

# Selects how the fused model is built: 'pcd', 'weight' or 'activation'.
feature_base = 'pcd'
# True: load the 'general' / 'specific' checkpoint pair for each run.
# False: load checkpoints numbered 2*i and 2*i+1 for run i.
SPECIALIST = True
# Passed to Deep_MLP as `which_act`; also part of the checkpoint filename.
act = 2

# Per-run accumulators, averaged over runs once the experiment loop is done.
acc_fuses = []
acc_naives = []
acc_as = []
acc_bs = []

# Hidden-layer sizes; their string form is embedded in checkpoint filenames.
hidden_sizes = [100, 100, 100]
# lambda values swept for every run (loop-invariant, so defined once).
lambdas = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

for i in range(5):
    model_a = Deep_MLP(hidden_size_1=100, hidden_size_2=100, hidden_size_3=100, which_act=act)
    model_b = Deep_MLP(hidden_size_1=100, hidden_size_2=100, hidden_size_3=100, which_act=act)
    if save:
        # Train both models and store them for later runs.
        model_a.train_model_best_ckpt(train_loader, test_loader, epochs=5)
        model_b.train_model_best_ckpt(train_loader, test_loader, epochs=5)

        model_a.save_model(f'saved/model_a_{i}')
        model_b.save_model(f'saved/model_b_{i}')
    elif SPECIALIST:
        # model_b gets the 'general' checkpoint, model_a the 'specific' one.
        model_b.load_model(f'saved_compression/{i}deepmlpmnist_general_{hidden_sizes}_{act}.checkpoint')
        model_a.load_model(f'saved_compression/{i}deepmlpmnist_specific_{hidden_sizes}_{act}.checkpoint')
    else:
        # Run i uses the consecutive checkpoints 2*i and 2*i+1.
        model_a.load_model(f'saved_compression/{2 * i}deepmlpmnist_{hidden_sizes}_{act}.checkpoint')
        model_b.load_model(f'saved_compression/{2 * i + 1}deepmlpmnist_{hidden_sizes}_{act}.checkpoint')

    test_a = model_a.test_model(test_loader, criterion=criterion)
    test_b = model_b.test_model(test_loader, criterion=criterion)
    # The parent models do not depend on lambda; repeat their result once per
    # lambda so the rows line up with the fused-model sweeps.
    acc_a = [test_a] * len(lambdas)
    acc_b = [test_b] * len(lambdas)

    # Fused model accuracies: acc_fuse maps alpha -> one accuracy per lambda,
    # acc_naive holds one accuracy per lambda.
    acc_fuse = {}
    acc_naive = []
    for lam in lambdas:
        # NOTE(review): [1 - lam, lam] presumably weights model_a / model_b
        # respectively -- confirm against FusionModel.
        naive_model = FusionModel(model_a, model_b, NaiveFusion(), lambdas=[1 - lam, lam])
        acc_naive.append(naive_model.test_model(test_loader, criterion=criterion))
        for alpha in [0, 0.2, 0.4, 0.5, 0.6, 0.8, 1]:
            # Only the fusion method and a few extra FusionModel keyword
            # arguments differ between the feature bases.
            if feature_base == 'pcd':
                fusion_method = PartialFusion(alphas=alpha, combine_costs=True)
                extra_kwargs = {'pgd': True}
            elif feature_base == 'weight':
                fusion_method = PartialFusion(alphas=alpha)
                extra_kwargs = {}
            elif feature_base == 'activation':
                fusion_method = PartialFusion(alphas=alpha)
                extra_kwargs = {'data': test_data}
            else:
                raise NotImplementedError(f'unknown feature_base: {feature_base!r}')
            fused_model = FusionModel(
                model_a, model_b,
                fusion_method,
                lambdas=[1 - lam, lam],
                **extra_kwargs,
            )
            accuracy = fused_model.test_model(test_loader, verbose=False, criterion=criterion)
            acc_fuse.setdefault(alpha, []).append(accuracy)
            print(fused_model.get_total_weights())
            print(fused_model.non_zero_weights)
            print(lam, alpha, accuracy)

    # Record each run exactly once (the parent-model rows used to be appended
    # twice per iteration, doubling the length of acc_as / acc_bs).
    acc_fuses.append(acc_fuse)
    acc_naives.append(acc_naive)
    acc_as.append(acc_a)
    acc_bs.append(acc_b)

# Collapse the per-run rows into one mean value per lambda (average over axis 0).
acc_naives = np.array(acc_naives).mean(axis=0)
acc_as = np.array(acc_as).mean(axis=0)
acc_bs = np.array(acc_bs).mean(axis=0)
for averaged in (acc_naives, acc_as, acc_bs):
    print(averaged)

# Same averaging for the fused models, done separately for each alpha.
# The alpha values are taken from the first run's dictionary.
sink_weights_list = sorted(acc_fuses[0])
results = {}
for alpha in sink_weights_list:
    per_run_rows = [run_result[alpha] for run_result in acc_fuses]
    acc_fuse_values = np.array(per_run_rows).mean(axis=0)
    print(alpha, acc_fuse_values)
    results[alpha] = acc_fuse_values
print(results)