# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils
import torch.nn.functional as F
import numpy as np
import json
from scipy import linalg
import copy
# Loss history shared across epochs: train_one_epoch appends one list of
# per-batch loss values per epoch, and save_loss writes the whole history to
# "<datapath>/loss_round.npy".
all_loss = []
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True):
    """Train ``model`` for one epoch and return the epoch-averaged metrics.

    For every batch, the network's intermediate features are blended with
    text-feature-aggregated features (see ``aggregate_att``). The blend is
    passed through the network head, scored with
    ``criterion(samples, outputs, targets)`` and back-propagated through
    ``loss_scaler``.

    Side effects:
        * appends this epoch's list of per-batch loss values to the
          module-level ``all_loss`` (later written out by ``save_loss``);
        * reads ``./text_features.npy`` once, on the first batch;
        * calls ``sys.exit(1)`` if a loss value is not finite.

    Args:
        model: the network; if it exposes the real model as ``model.module``
            (e.g. a DDP wrapper) that inner model is used for the forward
            pass. The network must provide ``get_intermediate_layers`` and
            ``forward_head``.
        criterion: loss called as ``criterion(samples, outputs, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        optimizer: optimizer whose gradients are zeroed each step.
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        loss_scaler: callable performing backward + optimizer step.
        max_norm: gradient-clipping value forwarded to ``loss_scaler``.
        model_ema: optional EMA model updated after every step.
        mixup_fn: optional mixup/cutmix transform applied to each batch.
        set_training_mode: flag passed to ``model.train``.

    Returns:
        dict mapping each metric-logger meter name (e.g. "loss", "lr") to its
        global average over the epoch.
    """
    model.train(set_training_mode)
    # Use the wrapped network when present (model.module), else the model itself.
    net = getattr(model, 'module', model)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10
    # Blend weight between backbone features and aggregated text features;
    # with alpha == 1 the aggregated features receive zero weight.
    alpha = 1
    epoch_losses = []
    # The text features do not depend on the batch, so they are loaded once
    # (lazily, on the first batch) instead of re-reading the file every step.
    text_feature = None
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        targets = targets.type(torch.int64)
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if text_feature is None:
            text_feature = np.load("./text_features.npy")

        with torch.cuda.amp.autocast():
            intermediate_output = net.get_intermediate_layers(samples, 1)
            aug_features = aggregate_att(intermediate_output, text_feature, 1.2, 2, 0.001)
            # NOTE(review): the return type of get_intermediate_layers is not
            # visible here; if it is a list/tuple, the arithmetic below is
            # sequence repetition rather than tensor scaling. Confirm.
            forward_feature = alpha * intermediate_output + (1 - alpha) * aug_features
            outputs = net.forward_head(forward_feature, True)
            loss = criterion(samples, outputs, targets)

        loss_value = alpha * loss.item()
        epoch_losses.append(loss_value)

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    all_loss.append(epoch_losses)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

def aggregate_att(w_clients, w_server, stepsize, metric, dp):
    """Move each entry of ``w_server`` towards the matching entry of ``w_clients``.

    For every integer index ``k`` in ``range(len(w_server))``::

        att_k     = softmax(norm(w_server[k] - w_clients[k], ord=metric))
        w_next[k] = w_server[k] - stepsize * att_k * (w_server[k] - w_clients[k])
                    + dp * N(0, 1)

    NOTE: each ``att_k`` is a softmax over a single scalar score, so it is
    always 1.0. The update therefore reduces to
    ``w_server[k] - stepsize * (w_server[k] - w_clients[k])`` plus Gaussian
    noise scaled by ``dp``. The score is still computed so that ``metric``
    keeps its meaning.

    Args:
        w_clients: indexable by the same integer keys as ``w_server``; its
            entries are tensors compatible with the ``w_server`` entries.
        w_server: indexable collection of tensors; deep-copied, not modified.
        stepsize: scalar step size towards ``w_clients``.
        metric: ``ord`` argument passed to ``scipy.linalg.norm``.
        dp: scale of the added standard-normal noise (0 disables it).

    Returns:
        A deep copy of ``w_server`` whose entries are the updated tensors.
    """
    w_next = copy.deepcopy(w_server)
    for k in range(len(w_server)):
        diff = w_server[k] - w_clients[k]
        # scipy needs a host numpy array: detach from autograd, cast to float
        # (so half-precision activations are accepted) and move to CPU first.
        # Passing the raw tensor failed for CUDA or grad-requiring tensors.
        score = linalg.norm(diff.detach().float().cpu().numpy(), ord=metric)
        att = F.softmax(torch.from_numpy(np.array(score)), dim=0)
        att_weight = torch.zeros_like(w_server[k])
        att_weight += torch.mul(diff, att)
        # Default dtype as before, but allocated on the weights' device so CUDA
        # inputs no longer mix CPU noise with GPU tensors.
        noise = torch.randn(w_server[k].shape, device=w_server[k].device)
        w_next[k] = w_server[k] - torch.mul(att_weight, stepsize) + torch.mul(noise, dp)
    return w_next

def save_loss(datapath):
    """Write the accumulated per-epoch loss history to ``<datapath>/loss_round.npy``.

    The module-level ``all_loss`` is converted to a numpy array. Its shape is
    printed, followed by the word "save", and then it is stored with
    ``np.save``.

    Args:
        datapath: output directory; "/loss_round.npy" is appended to it by
            string concatenation.
    """
    history = np.array(all_loss)
    print(history.shape)
    print("save")
    np.save(datapath + "/loss_round.npy", history)

_INFINITE_CONTRASTIVE_MEANS = '/data/liuyijiang/zhangrongyu/ActiveFT/output_boundary/cifar_class_means_embedding_boundary.npy'
_INFINITE_CONTRASTIVE_COVS = '/data/liuyijiang/zhangrongyu/ActiveFT/output_boundary/cifar_class_covs_embedding_boundary.npy'
# (means_filename, covs_filename, device) -> (feat_means, feat_covs)
_CLASS_STATS_CACHE = {}


def _load_class_stats(means_filename, covs_filename, device):
    """Load the per-class mean/covariance tensors onto ``device``, memoized per key."""
    key = (means_filename, covs_filename, str(device))
    if key not in _CLASS_STATS_CACHE:
        _CLASS_STATS_CACHE[key] = (
            torch.load(means_filename, map_location=device),
            torch.load(covs_filename, map_location=device),
        )
    return _CLASS_STATS_CACHE[key]


def InfiniteContrastiveV2(label, query, *, T=1,
                          means_filename=_INFINITE_CONTRASTIVE_MEANS,
                          covs_filename=_INFINITE_CONTRASTIVE_COVS):
    """Cross-entropy plus quadratic covariance penalty against per-class statistics.

    The class means ``mu`` and per-dimension covariances ``sigma`` are
    ``torch.load``-ed from the given files. They are cached after the first
    call (per file pair and device) instead of being re-read from disk on
    every call. Then::

        logits[n, k] = (q_n . mu_k + 0.5 * q_n^2 . (sigma_k / T)) / T
        loss_n       = CE(logits_n, label_n) + 0.5 * q_n^2 . (sigma_{label_n} / T) / T

    Args:
        label: (N,) integer class indices.
        query: (N, D) query features; the statistics are moved to its device.
        T: temperature (default 1, as before).
        means_filename: file holding a (K, D) tensor of class means.
        covs_filename: file holding a (K, D) tensor of per-dimension covariances.

    Returns:
        Scalar tensor: the mean of ``loss_n`` over the batch.
    """
    feat_means, feat_covs = _load_class_stats(means_filename, covs_filename, query.device)
    # N x K linear term q . mu_k
    query_mean = query.mm(feat_means.permute(1, 0).float())
    feat_covs = feat_covs / T
    # N x K quadratic term 0.5 * q^2 . sigma_k
    query_cov_query = 0.5 * query.pow(2).mm(feat_covs.permute(1, 0))
    # apply temperature
    logits = (query_mean + query_cov_query) / T
    ce_loss = F.cross_entropy(logits, label, reduction='none')

    # NOTE(review): feat_covs was already divided by T above and is divided by
    # T again here (T**2 overall); harmless while T == 1. Confirm intended.
    key_covs = feat_covs[label]
    jcl_loss = (0.5 * torch.sum(query.pow(2).mul(key_covs), dim=1)) / T
    return (ce_loss + jcl_loss).mean()

# def InfiniteContrastiveV3(label, query):
#     filename = '/data/liuyijiang/zhangrongyu/ActiveFT/output_boundary/cifar-10.npy'
#     means, covs = text_embedding(filename=filename)
#     T = 0.07
#     query_mean = query.mm(means.permute(1,0).float()) #N*K
#     covs = covs / T
#     query_cov_query = 0.5*query.pow(2).mm(covs.permute(1,0))
#     logits = query_mean + query_cov_query

#     # apply temperature
#     logits /= T
#     ce_loss = F.cross_entropy(logits, label, reduction='none')

#     key_covs = covs[label]
#     jcl_loss = (0.5 * torch.sum(query.pow(2).mul(key_covs), dim=1)) / T
#     # return (F.cross_entropy(logits, labels, reduction='none')*mask).mean() + jcl_loss
#     loss = ce_loss + jcl_loss
#     return loss.mean()

def text_embedding(filename, device="cuda"):
    """Compute per-class mean and per-dimension variance of saved text features.

    Args:
        filename: path to a ``torch.save``-d tensor indexed first by class.
            Each ``[c]`` slice is treated as (num_samples, dim), with rows
            as observations.
        device: device for the returned tensors (default "cuda", matching the
            previous hard-coded ``.cuda()``).

    Returns:
        ``(means, covs)``: two float32 tensors of shape (num_classes, dim).
        ``covs`` holds the diagonal of each class's unbiased (ddof=1) sample
        covariance.
    """
    all_text_features = torch.load(filename)

    all_means = []
    all_covs = []
    for c in range(all_text_features.size(0)):
        cls_text_features = all_text_features[c].cpu().numpy()
        all_means.append(torch.from_numpy(np.mean(cls_text_features, axis=0)))
        # diag(np.cov(x.T)) is exactly the per-column variance with ddof=1.
        # Computing it directly avoids building the full dim x dim matrix, and
        # it also works for dim == 1, where np.cov returns a 0-d array.
        all_covs.append(torch.from_numpy(
            np.var(cls_text_features, axis=0, ddof=1, dtype=np.float64)))
    means = torch.stack(all_means, dim=0).float().to(device)
    covs = torch.stack(all_covs, dim=0).float().to(device)
    return means, covs

@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Each batch is moved to ``device`` and run under CUDA autocast. It is
    scored with cross-entropy and top-1 / top-5 accuracy. The logger's
    ``synchronize_between_processes()`` is called before a one-line summary
    is printed.

    Returns:
        dict mapping each meter name (e.g. "loss", "acc1", "acc5") to its
        global average.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    logger = utils.MetricLogger(delimiter="  ")

    # switch to evaluation mode
    model.eval()

    for batch, labels in logger.log_every(data_loader, 10, 'Test:'):
        batch = batch.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            logits = model(batch)
            batch_loss = loss_fn(logits, labels)

        top1, top5 = accuracy(logits, labels, topk=(1, 5))

        n_images = batch.shape[0]
        logger.update(loss=batch_loss.item())
        # accuracies are weighted by batch size; the loss update passes no explicit n
        logger.meters['acc1'].update(top1.item(), n=n_images)
        logger.meters['acc5'].update(top5.item(), n=n_images)

    # gather the stats from all processes
    logger.synchronize_between_processes()
    summary = '* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
    print(summary.format(top1=logger.acc1, top5=logger.acc5, losses=logger.loss))

    return {name: meter.global_avg for name, meter in logger.meters.items()}
