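# Testing visualization: probe whether features from a trained feature extractor reveal a red-square marker in the input images.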
import argparse
from tqdm import tqdm

import torch
import torch.nn.functional as F

import torchvision
from torchvision import models, datasets, transforms
from torch import nn
from torch.utils.data import DataLoader


import matplotlib.pyplot as plt
import seaborn as sns

from FLAlgorithms.trainmodel.models import VGG16, MobileNetV2, VisionTransformer, ResNet18, MinimalDecoder, FeatureClF
import os
 
import copy
import numpy as np

import warnings
# NOTE(review): this silences *all* Python warnings for the whole process
# (including library deprecation warnings); consider narrowing it to the
# specific categories/modules that are actually noisy.
warnings.filterwarnings('ignore')

def pairwise_dist(A):
    """Return the matrix of Euclidean distances between the rows of ``A``.

    Adapted from
    https://stackoverflow.com/questions/37009647/compute-pairwise-distance-in-a-batch-without-replicating-tensor-in-tensorflow

    Uses the expansion ``||a_i - a_j||^2 = ||a_i||^2 - 2 a_i.a_j + ||a_j||^2``.
    Squared distances are clamped to at least ``1e-7`` before the square root,
    which removes small negative round-off values and keeps the gradient of
    ``sqrt`` finite on (near-)zero distances such as the diagonal.

    Args:
        A: 2-D tensor of shape ``(n, d)``.

    Returns:
        ``(n, n)`` tensor of pairwise distances, on the same device as ``A``.
    """
    sq_norms = torch.sum(A * A, 1).reshape(-1, 1)  # (n, 1) squared row norms
    # (n, 1) + (1, n) broadcasting replaces the explicit repeat() copies;
    # the element-wise arithmetic (and its evaluation order) is unchanged.
    D = sq_norms - 2 * torch.matmul(A, A.T) + sq_norms.T
    # clamp avoids allocating an (n, n) matrix of ones just to hold the floor.
    D = torch.clamp(D, min=1e-7)
    return torch.sqrt(D)

def dist_corr(X, F):
    """Return the sample distance correlation between the rows of ``X`` and ``F``.

    Both inputs are passed to ``pairwise_dist``, so they are expected to be
    2-D with the same number of rows ``n``. The distance matrices are
    double-centred (row mean and column mean subtracted, grand mean added
    back), then

        dCor = dCov(X, F) / sqrt(dVar(X) * dVar(F))

    with ``1e-7`` added inside every square root for numerical stability.
    Note: inside this function ``F`` is the argument, not ``torch.nn.functional``.
    """
    n = X.shape[0]
    eps = 1e-7
    denom = n ** 2

    def _double_center(dist):
        # Broadcast (n, 1) row means and (1, n) column means over the matrix.
        return (dist
                - dist.mean(dim=1, keepdim=True)
                - dist.mean(dim=0, keepdim=True)
                + dist.mean())

    A = _double_center(pairwise_dist(X))
    B = _double_center(pairwise_dist(F))

    dcov_xy = torch.sqrt(torch.sum(A * B) / denom + eps)
    dvar_xx = torch.sqrt(torch.sum(A * A) / denom + eps)
    dvar_yy = torch.sqrt(torch.sum(B * B) / denom + eps)
    return dcov_xy / (torch.sqrt(dvar_xx + eps) * torch.sqrt(dvar_yy + eps))


class Corelation(nn.Module):
    """Distance-correlation loss between inputs and features, divided by batch size.

    Both tensors are flattened to ``(n, -1)``, where ``n`` is the size of
    dim 0, and compared with ``dist_corr``. The module has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, data, feaure):
        batch = data.shape[0]
        flat_data = data.reshape(batch, -1)
        flat_feat = feaure.reshape(batch, -1)
        return dist_corr(flat_data, flat_feat) / batch
   
    
class NormalizeInverse(transforms.Normalize):
    """Undo a ``transforms.Normalize(mean, std)`` step.

    Normalisation maps ``x -> (x - mean) / std``. Its inverse,
    ``y -> y * std + mean``, is itself a normalisation with mean
    ``-mean / std`` and std ``1 / std``. ``1e-7`` is added to ``std`` before
    inverting to avoid division by zero, so the inverse is approximate.
    """

    def __init__(self, mean, std):
        mean_t = torch.as_tensor(mean)
        inv_std = 1 / (torch.as_tensor(std) + 1e-7)
        super().__init__(mean=-mean_t * inv_std, std=inv_std)

    def __call__(self, tensor):
        # Work on a copy so the caller's tensor is left untouched.
        return super().__call__(tensor.clone())
        
class Normalize(transforms.Normalize):
    """Normalise a tensor image channel-wise: ``(x - mean) / std``.

    Thin wrapper around ``transforms.Normalize``. ``mean`` and ``std`` are
    converted with ``torch.as_tensor`` (so lists, tuples and tensors are
    accepted). The input tensor is cloned first, so it is never modified.
    """

    def __init__(self, mean, std):
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        super().__init__(mean=mean, std=std)

    def __call__(self, tensor):
        # Clone so the caller's tensor is never modified.
        return super().__call__(tensor.clone())

def get_CIFAR10(root="./"):
    """Build the CIFAR-10 train/test datasets, downloading them if missing.

    Images come out of ``transforms.ToTensor`` un-normalised; any
    normalisation is left to the caller. Training images also get
    random-crop (32 px, padding 4) and random horizontal-flip augmentation.

    Args:
        root: directory under which ``data/CIFAR10`` is stored. Joined with
            ``os.path.join``, so a trailing separator is optional.

    Returns:
        tuple: ``(input_size, num_classes, train_dataset, test_dataset)``
        with ``input_size == 32`` and ``num_classes == 10``.
    """
    input_size = 32
    num_classes = 10
    data_dir = os.path.join(root, "data", "CIFAR10")

    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    train_dataset = datasets.CIFAR10(
        data_dir, train=True, transform=train_transform, download=True
    )

    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    test_dataset = datasets.CIFAR10(
        data_dir, train=False, transform=test_transform, download=True
    )

    return input_size, num_classes, train_dataset, test_dataset


# Index passed as ``idx=layer`` to ``feature_extractor.get_feature`` in train()/test().
layer = 0
# Global torch seed, set at import time before the datasets are built.
torch.manual_seed(0)
# Builds (downloading if needed) CIFAR-10 at import time; main() reads
# train_dataset / test_dataset from these module globals.
input_size, num_classes, train_dataset, test_dataset = get_CIFAR10()    
# Per-channel mean/std used by the normalise / inverse-normalise transforms below.
mean, std = [0.49139968, 0.48215827, 0.44653124],[0.24703233, 0.24348505, 0.26158768]
NI = NormalizeInverse(mean, std)  # inverse normalisation; NOTE(review): not referenced elsewhere in this file
NM = Normalize(mean, std)         # normalisation applied to inputs in train()/test()
 

def train(model, train_loader, optimizer, epoch, feature_extractor):
    """Train the probe classifier ``model`` for one epoch on extracted features.

    Each batch is normalised with the module-level ``NM`` transform, encoded by
    ``feature_extractor.get_feature`` at the module-level ``layer`` index, and
    classified by ``model`` with cross-entropy loss.

    Args:
        model: classifier on top of the extracted features; updated in place.
        train_loader: iterable of ``(data, target)`` batches, where ``data``
            is a ``(B, C, H, W)`` image batch.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.
        feature_extractor: encoder exposing ``get_feature(x, idx=...)``. It is
            run without autograd, so no gradients reach it.

    Returns:
        list[float]: the per-batch training losses.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss = []
    for data, target in tqdm(train_loader):
        # Normalize broadcasts over a (B, C, H, W) batch, so one call replaces
        # the per-sample loop.
        data_n = NM(data)
        # The extractor is used as a fixed encoder: skip building its autograd
        # graph and accumulating .grad on parameters that are never optimised.
        with torch.no_grad():
            feature = feature_extractor.get_feature(data_n.cuda(), idx=layer)
        logit = model(feature)
        loss = criterion(logit, target.cuda())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss.append(loss.item())

    # Guard against an empty loader instead of raising ZeroDivisionError.
    avg_loss = sum(total_loss) / len(total_loss) if total_loss else float("nan")
    print(f"Epoch: {epoch}:")
    print(f"Train Set: Average Loss: {avg_loss:.3f}")
    return total_loss

def test(model, test_loader, feature_extractor):
    """Evaluate the probe classifier; print loss and accuracy, return mean loss.

    Inputs are normalised with the module-level ``NM`` transform and encoded
    by ``feature_extractor.get_feature`` at the module-level ``layer`` index,
    all under ``torch.no_grad()``.

    Args:
        model: classifier on top of the extracted features.
        test_loader: iterable of ``(data, target)`` batches, where ``data``
            is a ``(B, C, H, W)`` image batch.
        feature_extractor: encoder exposing ``get_feature(x, idx=...)``.

    Returns:
        float: mean per-batch cross-entropy loss (NaN for an empty loader).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = []
    correct = 0
    cnt = 0
    with torch.no_grad():
        for data, target in test_loader:
            # Normalize broadcasts over the whole (B, C, H, W) batch.
            data_n = NM(data)
            feature = feature_extractor.get_feature(data_n.cuda(), idx=layer)
            logit = model(feature)
            loss = criterion(logit, target.cuda())
            total_loss.append(loss.item())

            predictions = torch.argmax(logit, dim=1).cpu()
            correct += (predictions == target).sum().item()
            cnt += len(data)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    avg_loss = sum(total_loss) / len(total_loss) if total_loss else float("nan")
    # This is overall accuracy (correct / total), not recall of one class.
    accuracy = correct / cnt if cnt else float("nan")
    print(f"Testing Set: Average Loss: {avg_loss:.3f}", 'accuracy:', accuracy)

    return avg_loss
    
def add_privcay(data, type='None'):
    """Optionally stamp a marker onto an image tensor, in place, and return it.

    With ``type == 'red_square'``, channel 0 of rows/columns 2..6 (a 5x5
    patch) is set to 1. Any other ``type`` leaves ``data`` unchanged. The
    same tensor object is returned in both cases.
    """
    if type != 'red_square':
        return data
    data[0, 2:7, 2:7] = 1
    return data

 
def _save_examples(dataset, path='data_examples.png', count=10):
    """Save the first ``count`` samples of ``dataset``, with the red square added, as one image strip."""
    plt.figure(figsize=(20, 10))
    for i in range(count):
        data, _ = dataset[i]
        data = add_privcay(data, 'red_square')
        ax = plt.subplot(1, count, i + 1)
        plt.imshow(transforms.ToPILImage()(data))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.savefig(path)
    # Release the figure; it is not reused.
    plt.close()


def _build_marked_split(dataset, split, end):
    """Return ``[image, label]`` pairs for ``dataset[0:end]``.

    Samples ``[0, split)`` are copied unchanged with label 0; samples
    ``[split, end)`` get the red-square marker and label 1.
    """
    samples = []
    for i in range(end):
        data, _ = dataset[i]
        if i < split:
            samples.append([data.clone(), 0])
        else:
            samples.append([add_privcay(data, 'red_square'), 1])
    return samples


def main():
    """Probe whether a trained feature extractor reveals the red-square marker.

    Loads a pickled feature extractor from ``--checkpoint``. It then trains
    ``FeatureClF`` on that extractor's features (taken at the module-level
    ``layer`` index) to separate clean images (label 0) from images marked by
    ``add_privcay`` (label 1), reporting test loss and accuracy every epoch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train (default: 20)")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate (default: 0.001)")
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=os.path.join(
            '/auto/homes/tx229/federated/FLea/models/test',
            'c0.4_server_FLea_test_MOBNET_Cifar10_loss_MCE_DeC_KL_epoch_10_500_client_100_split_quantity_3.0.pt',
        ),
        help="path to the pickled feature-extractor model",
    )
    args = parser.parse_args()
    print(args)
    # --seed used to be parsed but ignored; seed here so it actually applies.
    torch.manual_seed(args.seed)

    ## visualise data
    _save_examples(test_dataset, 'data_examples.png')

    ######## load feature extractor
    print(args.checkpoint)
    # torch.load unpickles a whole model object, which can execute arbitrary
    # code: only load trusted checkpoints.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # full-model pickles; pass weights_only=False there if loading fails.
    feature_extractor = torch.load(args.checkpoint).cuda()
    print('Load model checkpoint from name succuessfully!')
    feature_extractor.eval()

    ## probe classifier trained on the extracted features
    model = FeatureClF()
    print(model)
    model = model.cuda()

    kwargs = {"num_workers": 2, "pin_memory": True}

    ## make training and testing data: first half clean (0), second half marked (1)
    TRAIN_len = 5000
    training_data = _build_marked_split(train_dataset, TRAIN_len, 2 * TRAIN_len)
    half = len(test_dataset) // 2
    testing_data = _build_marked_split(test_dataset, half, len(test_dataset))

    print('Training samples:', len(training_data), 'testing samples:', len(testing_data))
    train_loader = torch.utils.data.DataLoader(training_data, batch_size=32, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(testing_data, batch_size=128, shuffle=False, **kwargs)

    # NOTE: with the default 20 epochs none of these milestones is reached,
    # so the learning rate stays constant.
    milestones = [25, 50, 80]
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    for epoch in range(1, args.epochs + 1):
        train(model, train_loader, optimizer, epoch, feature_extractor)
        test(model, test_loader, feature_extractor)
        scheduler.step()


if __name__ == "__main__":
    main()