import os
from network.get_network import GetNetwork
from utils.optimizer import SAM
from utils import swa_utils
from data.pacs_dataset import denormalization
import torch
from configs.default import *
import torch.nn.functional as F
from tqdm import tqdm
import random
from hessian_eigenthings import compute_hessian_eigenthings
from copy import deepcopy

def Shuffle_Batch_Data(data_in):
    """Return the entries of ``data_in`` in a random order.

    A permutation of ``range(len(data_in))`` is drawn with Python's ``random``
    module and handed to ``data_in`` as one index list, so ``data_in`` must
    accept a list of integers as an index.
    """
    permutation = [*range(len(data_in))]
    random.shuffle(permutation)
    shuffled = data_in[permutation]
    return shuffled


def MultiClassCrossEntropy(logits, labels, T):
    """Cross-entropy between two temperature-scaled logit tensors.

    Both ``logits`` and ``labels`` are divided by ``T``. ``labels`` is turned
    into a probability distribution with a softmax over ``dim=1`` and
    ``logits`` into log-probabilities with a log-softmax over ``dim=1``.

    Returns:
        The negated product of the two, summed over ``dim=1`` and averaged
        over ``dim=0``.
    """
    log_probs = F.log_softmax(logits / T, dim=1)
    soft_targets = F.softmax(labels / T, dim=1)
    # Per-row sum over dim=1 first, then the mean over dim=0.
    per_sample = (log_probs * soft_targets).sum(dim=1)
    return -per_sample.mean(dim=0)


def epoch_site_train(epochs, site_name, model, optimzier, scheduler, dataloader, log_ten, metric):
    """Run one epoch of plain cross-entropy training on a single site.

    Args:
        epochs: Global epoch index. Used as the step of the per-epoch
            accuracy scalar and, combined with the batch index, of the
            per-batch loss scalar.
        site_name: Name used as the prefix of the logged scalar tags.
        model: Network to train; put into train mode by this function.
        optimzier: Optimizer updating ``model`` (the spelling is kept because
            it is part of the signature).
        scheduler: Learning-rate scheduler stepped once after the epoch, or
            ``None`` to skip scheduling.
        dataloader: Iterable of ``(imgs, labels, domain_labels)`` batches that
            also supports ``len()``.
        log_ten: Logger providing ``add_scalar(tag, value, step)``.
        metric: Accumulator providing ``update(output, labels)`` and
            ``results()``, whose result is indexed with ``'acc'``. It is not
            reset here.
    """
    model.train()
    for i, data_list in enumerate(dataloader):
        # Domain labels are not used by this objective, so they stay on the
        # host instead of being copied to the GPU on every batch.
        imgs, labels, _ = data_list
        imgs = imgs.cuda()
        labels = labels.cuda()

        optimzier.zero_grad()
        output = model(imgs)
        loss = F.cross_entropy(output, labels)
        loss.backward()
        optimzier.step()

        log_ten.add_scalar(f'{site_name}_train/loss', loss.item(), epochs*len(dataloader)+i)
        metric.update(output, labels)

    log_ten.add_scalar(f'{site_name}_train/acc', metric.results()['acc'], epochs)
    if scheduler is not None:
        scheduler.step()


def epoch_site_train_sam(epochs, site_name, model, optimizer, scheduler, dataloader, log_ten, metric, rho, eta):
    """Run one epoch of training through the project's ``SAM`` wrapper.

    Each batch gets two forward/backward passes: the gradient of the first is
    consumed by ``minimizer.ascent_step()`` and the gradient of the second by
    ``minimizer.descent_step()``. The logged loss and the metric update use
    the output of the first pass.

    Args:
        epochs: Global epoch index used for the logging steps.
        site_name: Name used as the prefix of the logged scalar tags.
        model: Network to train; put into train mode by this function.
        optimizer: Base optimizer handed to ``SAM``.
        scheduler: Learning-rate scheduler stepped once after the epoch, or
            ``None`` to skip scheduling.
        dataloader: Iterable of ``(imgs, labels, domain_labels)`` batches that
            also supports ``len()``.
        log_ten: Logger providing ``add_scalar(tag, value, step)``.
        metric: Accumulator providing ``update(output, labels)`` and
            ``results()``.
        rho: Forwarded to ``SAM``.
        eta: Forwarded to ``SAM``.
    """
    minimizer = SAM(optimizer, model, rho, eta)

    model.train()
    for i, data_list in enumerate(dataloader):
        imgs, labels, domain_labels = data_list
        imgs = imgs.cuda()
        labels = labels.cuda()

        # NOTE(review): no zero_grad() is issued here; gradient clearing is
        # presumably done inside the SAM steps -- confirm in utils.optimizer.
        output = model(imgs)
        loss = F.cross_entropy(output, labels)
        loss.backward()
        minimizer.ascent_step()

        F.cross_entropy(model(imgs), labels).backward()
        minimizer.descent_step()

        log_ten.add_scalar(f'{site_name}_train/loss', loss.item(), epochs * len(dataloader) + i)
        metric.update(output, labels)

    log_ten.add_scalar(f'{site_name}_train/acc', metric.results()['acc'], epochs)
    # The scheduler is None when no LR policy is configured, matching the
    # guard used by epoch_site_train.
    if scheduler is not None:
        scheduler.step()


def epoch_site_train_lwf(epochs, site_name, model, model_central, optimzier, scheduler, dataloader, log_ten, metric, lamda, T):
    """Run one epoch of training with an extra distillation term.

    The loss is ``cross_entropy(output, labels) + lamda * loss_old`` where
    ``loss_old`` is ``MultiClassCrossEntropy`` between the outputs of
    ``model`` and those of ``model_central`` at temperature ``T``.

    Args:
        epochs: Global epoch index used for the logging steps.
        site_name: Name used as the prefix of the logged scalar tags.
        model: Network being trained; put into train mode by this function.
        model_central: Reference network whose outputs act as soft targets.
            It is evaluated under ``torch.no_grad()`` and never optimised.
        optimzier: Optimizer updating ``model`` (the spelling is kept because
            it is part of the signature).
        scheduler: Learning-rate scheduler stepped once after the epoch, or
            ``None`` to skip scheduling.
        dataloader: Iterable of ``(imgs, labels, domain_labels)`` batches that
            also supports ``len()``.
        log_ten: Logger providing ``add_scalar(tag, value, step)``.
        metric: Accumulator providing ``update(output, labels)`` and
            ``results()``.
        lamda: Weight of the distillation term.
        T: Temperature passed to ``MultiClassCrossEntropy``.
    """
    model.train()
    # NOTE(review): the reference model is also put in train mode, so any
    # mode-dependent layers run in training mode -- confirm this is intended.
    model_central.train()
    for i, data_list in enumerate(dataloader):
        imgs, labels, domain_labels = data_list
        imgs = imgs.cuda()
        labels = labels.cuda()

        optimzier.zero_grad()
        output = model(imgs)
        with torch.no_grad():
            labels_old = model_central(imgs)

        loss_new = F.cross_entropy(output, labels)
        loss_old = MultiClassCrossEntropy(output, labels_old, T=T)
        loss = loss_new + lamda * loss_old

        loss.backward()
        optimzier.step()

        log_ten.add_scalar(f'{site_name}_train/loss_old', loss_old.item(), epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss_new', loss_new.item(), epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss', loss.item(), epochs * len(dataloader) + i)
        metric.update(output, labels)

    log_ten.add_scalar(f'{site_name}_train/acc', metric.results()['acc'], epochs)
    # The scheduler is None when no LR policy is configured, matching the
    # guard used by epoch_site_train.
    if scheduler is not None:
        scheduler.step()


def epoch_site_train_flwf(epochs, site_name, model, model_central, optimizer, scheduler, dataloader, log_ten, metric, rho, eta, lamda, T):
    """Run one epoch combining the ``SAM`` wrapper with a distillation term.

    Each batch computes ``cross_entropy + lamda * MultiClassCrossEntropy``
    twice: once before ``minimizer.ascent_step()`` and once before
    ``minimizer.descent_step()``. The soft targets from ``model_central`` are
    computed once per batch and reused for both passes. The logged losses and
    the metric update use the values of the second pass.

    Args:
        epochs: Global epoch index used for the logging steps.
        site_name: Name used as the prefix of the logged scalar tags.
        model: Network being trained; put into train mode by this function.
        model_central: Reference network whose outputs act as soft targets;
            evaluated under ``torch.no_grad()``.
        optimizer: Base optimizer handed to ``SAM``.
        scheduler: Learning-rate scheduler stepped once after the epoch, or
            ``None`` to skip scheduling.
        dataloader: Iterable of ``(imgs, labels, domain_labels)`` batches that
            also supports ``len()``.
        log_ten: Logger providing ``add_scalar(tag, value, step)``.
        metric: Accumulator providing ``update(output, labels)`` and
            ``results()``.
        rho: Forwarded to ``SAM``.
        eta: Forwarded to ``SAM``.
        lamda: Weight of the distillation term.
        T: Temperature passed to ``MultiClassCrossEntropy``.
    """
    minimizer = SAM(optimizer, model, rho, eta)

    model.train()
    model_central.train()
    for i, data_list in enumerate(dataloader):
        imgs, labels, domain_labels = data_list
        imgs = imgs.cuda()
        labels = labels.cuda()

        output = model(imgs)
        with torch.no_grad():
            labels_old = model_central(imgs)

        # First pass: gradient consumed by the ascent step.
        loss_new = F.cross_entropy(output, labels)
        loss_old = MultiClassCrossEntropy(output, labels_old, T=T)
        loss = loss_new + lamda * loss_old
        loss.backward()
        minimizer.ascent_step()

        # Second pass: same objective, gradient consumed by the descent step.
        output = model(imgs)
        loss_new = F.cross_entropy(output, labels)
        loss_old = MultiClassCrossEntropy(output, labels_old, T=T)
        loss = loss_new + lamda * loss_old
        loss.backward()
        minimizer.descent_step()

        log_ten.add_scalar(f'{site_name}_train/loss_old', loss_old.item(), epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss_new', loss_new.item(), epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss', loss.item(), epochs * len(dataloader) + i)
        metric.update(output, labels)

    log_ten.add_scalar(f'{site_name}_train/acc', metric.results()['acc'], epochs)
    # The scheduler is None when no LR policy is configured, matching the
    # guard used by epoch_site_train.
    if scheduler is not None:
        scheduler.step()


def site_train(comm_rounds, site_name, args, model, optimizer, scheduler, dataloader, log_ten, metric):
    """Run ``args.local_epochs`` local epochs for one site.

    The per-epoch routine is chosen from the flags on ``args`` in this
    priority order: ``sam``, ``lwf``, ``flwf``, otherwise plain training.

    Args:
        comm_rounds: Index of the current communication round; combined with
            the local epoch index to form the global epoch passed down.
        site_name: Name used for the progress bar and logging tags.
        args: Namespace providing ``local_epochs``, the ``sam`` / ``lwf`` /
            ``flwf`` flags and the hyper-parameters each variant reads
            (``rho``, ``eta``, ``lamda``, ``T``).
        model: Network to train.
        optimizer: Optimizer for ``model``.
        scheduler: Learning-rate scheduler forwarded to the epoch routine.
        dataloader: Training data loader for this site.
        log_ten: Scalar logger forwarded to the epoch routine.
        metric: Metric accumulator forwarded to the epoch routine.
    """
    tbar = tqdm(range(args.local_epochs))

    # Only the LwF variants read the pre-training snapshot, and they run only
    # when SAM is off, so skip the full-model copy for the other modes.
    needs_central = (not args.sam) and (args.lwf or args.flwf)
    model_central = deepcopy(model) if needs_central else None

    for local_epoch in tbar:
        tbar.set_description(f'{site_name}_train')
        if args.sam:
            epoch_site_train_sam(comm_rounds*args.local_epochs + local_epoch, site_name, model, optimizer, scheduler, dataloader, log_ten, metric, args.rho, args.eta)
        elif args.lwf:
            epoch_site_train_lwf(comm_rounds*args.local_epochs + local_epoch, site_name, model, model_central, optimizer, scheduler, dataloader, log_ten, metric, args.lamda, args.T)
        elif args.flwf:
            epoch_site_train_flwf(comm_rounds*args.local_epochs + local_epoch, site_name, model, model_central, optimizer, scheduler, dataloader, log_ten, metric, args.rho, args.eta, args.lamda, args.T)
        else:
            epoch_site_train(comm_rounds*args.local_epochs + local_epoch, site_name, model, optimizer, scheduler, dataloader, log_ten, metric)


def epoch_site_train_ta(epochs, site_name, objective, optimizer, optimizer_aug, scheduler, n_inner, dataloader, log_ten, metric):
    """Run one epoch that alternates augmentation and classifier updates.

    Every ``n_inner``-th batch (starting with batch 0) the augmentation side
    is updated through ``objective(imgs, labels, context, 'aug')`` and
    ``optimizer_aug``. On every batch the classifier ``objective.model`` is
    then trained with cross-entropy on inputs produced by
    ``objective.trainable_aug`` followed by ``objective.base_aug``. Those
    inputs are built under ``torch.no_grad()``.

    Args:
        epochs: Global epoch index used for the logging steps.
        site_name: Name used as the prefix of the logged tags.
        objective: Callable object exposing ``model``, ``trainable_aug`` and
            ``base_aug``. Calling it returns ``(loss_aug, res)``, where
            ``res`` is indexed with the keys logged below.
        optimizer: Optimizer for the classifier.
        optimizer_aug: Optimizer for the augmentation update.
        scheduler: Learning-rate scheduler stepped once after the epoch, or
            ``None`` to skip scheduling.
        n_inner: Period, in batches, of the augmentation update.
        dataloader: Iterable of ``(imgs, labels, domain_labels)`` batches that
            also supports ``len()``.
        log_ten: Logger providing ``add_scalar`` and ``add_image``.
        metric: Accumulator providing ``update(pred, labels)`` and
            ``results()``.
    """
    for i, data_list in enumerate(dataloader):
        imgs, labels, domain_labels = data_list
        imgs = imgs.cuda()
        labels = labels.cuda()
        context = labels

        if i % n_inner == 0:
            optimizer_aug.zero_grad()
            loss_aug, res = objective(imgs, labels, context, 'aug')
            loss_aug.backward()
            optimizer_aug.step()

        optimizer.zero_grad()
        with torch.no_grad():
            aug_x, _ = objective.trainable_aug(imgs, context)
            inputs = torch.stack([objective.base_aug(_x) for _x in aug_x])
            # inputs = torch.stack([objective.base_aug(_x) for _x in imgs])

        # calculate loss
        pred = objective.model(inputs)
        loss = F.cross_entropy(pred, labels)
        loss.backward()
        optimizer.step()

        # ``res`` is refreshed only on augmentation-update batches; in between
        # the values from the most recent update are logged again.
        log_ten.add_scalar(f'{site_name}_train/loss_adv', res['loss adv.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss_teacher', res['loss teacher'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss_color', res['color reg.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/acc_student', res['acc.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/acc_teacher', res['acc. teacher'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss', loss.item(), epochs * len(dataloader) + i)
        metric.update(pred, labels)

    log_ten.add_scalar(f'{site_name}_train/acc', metric.results()['acc'], epochs)
    # Log the first sample of the last batch before and after augmentation.
    log_ten.add_image(f'{site_name}/before', imgs[0], epochs)
    log_ten.add_image(f'{site_name}/after', aug_x[0], epochs)
    # log_ten.add_image(f'{site_name}/after', denormalization(inputs[0]), epochs)
    # The scheduler is None when no LR policy is configured, matching the
    # guard used by epoch_site_train.
    if scheduler is not None:
        scheduler.step()


def epoch_site_train_fta(epochs, site_name, objective, optimizer, optimizer_aug, scheduler, n_inner, dataloader, log_ten, metric, rho, eta):
    """Run one epoch of augmentation training with ``SAM`` classifier steps.

    Every ``n_inner``-th batch (starting with batch 0) the augmentation side
    is updated through ``objective(imgs, labels, context, 'aug')`` and
    ``optimizer_aug``. On every batch the classifier ``objective.model`` gets
    two forward/backward passes on the augmented inputs: one before
    ``minimizer.ascent_step()`` and one before ``minimizer.descent_step()``.
    The logged loss and the metric update use the first pass.

    Args:
        epochs: Global epoch index used for the logging steps.
        site_name: Name used as the prefix of the logged tags.
        objective: Callable object exposing ``model``, ``trainable_aug`` and
            ``base_aug``. Calling it returns ``(loss_aug, res)``.
        optimizer: Base optimizer handed to ``SAM``.
        optimizer_aug: Optimizer for the augmentation update.
        scheduler: Learning-rate scheduler stepped once after the epoch, or
            ``None`` to skip scheduling.
        n_inner: Period, in batches, of the augmentation update.
        dataloader: Iterable of ``(imgs, labels, domain_labels)`` batches that
            also supports ``len()``.
        log_ten: Logger providing ``add_scalar`` and ``add_image``.
        metric: Accumulator providing ``update(pred, labels)`` and
            ``results()``.
        rho: Forwarded to ``SAM``.
        eta: Forwarded to ``SAM``.
    """
    minimizer = SAM(optimizer, objective.model, rho, eta)
    for i, data_list in enumerate(dataloader):
        imgs, labels, domain_labels = data_list
        imgs = imgs.cuda()
        labels = labels.cuda()
        context = labels

        if i % n_inner == 0:
            optimizer_aug.zero_grad()
            loss_aug, res = objective(imgs, labels, context, 'aug')
            loss_aug.backward()
            optimizer_aug.step()

        with torch.no_grad():
            aug_x, _ = objective.trainable_aug(imgs, context)
            inputs = torch.stack([objective.base_aug(_x) for _x in aug_x])
            # inputs = torch.stack([objective.base_aug(_x) for _x in imgs])

        # calculate loss
        pred = objective.model(inputs)
        loss = F.cross_entropy(pred, labels)
        # Clear the wrapped optimizer's gradients before the first SAM pass.
        # NOTE(review): presumably needed because the augmentation backward
        # above can leave gradients on the classifier -- confirm.
        minimizer.optimizer.zero_grad()
        loss.backward()
        minimizer.ascent_step()

        F.cross_entropy(objective.model(inputs), labels).backward()
        minimizer.descent_step()

        # ``res`` is refreshed only on augmentation-update batches; in between
        # the values from the most recent update are logged again.
        log_ten.add_scalar(f'{site_name}_train/loss_adv', res['loss adv.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss_teacher', res['loss teacher'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss_color', res['color reg.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/acc_student', res['acc.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/acc_teacher', res['acc. teacher'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss', loss.item(), epochs * len(dataloader) + i)
        metric.update(pred, labels)

    log_ten.add_scalar(f'{site_name}_train/acc', metric.results()['acc'], epochs)
    # Log the first sample of the last batch before and after augmentation.
    log_ten.add_image(f'{site_name}/before', imgs[0], epochs)
    log_ten.add_image(f'{site_name}/after', aug_x[0], epochs)
    # log_ten.add_image(f'{site_name}/after', denormalization(inputs[0]), epochs)
    # The scheduler is None when no LR policy is configured, matching the
    # guard used by epoch_site_train.
    if scheduler is not None:
        scheduler.step()


def site_train_ta(comm_rounds, site_name, args, objective, optimizer, optimizer_aug, scheduler, dataloader, log_ten, metric):
    """Run ``args.local_epochs`` local epochs of augmentation-based training.

    ``args.ta`` selects ``epoch_site_train_ta``; otherwise ``args.fta``
    selects ``epoch_site_train_fta``. If neither flag is set the loop only
    advances the progress bar.
    """
    progress = tqdm(range(args.local_epochs))
    for local_epoch in progress:
        progress.set_description(f'{site_name}_train')
        # Global epoch index across communication rounds.
        global_epoch = comm_rounds * args.local_epochs + local_epoch

        if args.ta:
            epoch_site_train_ta(
                global_epoch, site_name, objective, optimizer, optimizer_aug,
                scheduler, args.n_inner, dataloader, log_ten, metric)
            continue

        if args.fta:
            epoch_site_train_fta(
                global_epoch, site_name, objective, optimizer, optimizer_aug,
                scheduler, args.n_inner, dataloader, log_ten, metric,
                args.rho, args.eta)


def epoch_site_train_swad(epochs, step, site_name, args, objective, optimizer, optimizer_aug, scheduler, n_inner, dataloader, loader_val, swad_algorithm, swad, log_ten, metric):
    """Run one epoch of augmentation training with SWAD bookkeeping.

    The per-batch training matches the augmentation routine: a periodic
    augmentation update every ``n_inner`` batches followed by a cross-entropy
    step on ``objective.model``. After each batch a copy of the model is fed
    to ``swad_algorithm.update_parameters``. Every ``check_freq`` steps the
    model is evaluated on ``loader_val`` and the result is passed to
    ``swad.update_and_evaluate``. If ``swad.dead_valley`` becomes true the
    epoch stops early; otherwise a fresh averaged model is started.

    Args:
        epochs: Global epoch index used for the logging steps.
        step: Running step counter carried across epochs.
        site_name: Name used as the prefix of the logged tags.
        args: Namespace providing ``local_epochs`` (and whatever
            ``site_evaluation`` reads).
        objective: Callable object exposing ``model``, ``trainable_aug`` and
            ``base_aug``. Calling it returns ``(loss_aug, res)``.
        optimizer: Optimizer for the classifier.
        optimizer_aug: Optimizer for the augmentation update.
        scheduler: Learning-rate scheduler stepped once after the epoch, or
            ``None`` to skip scheduling.
        n_inner: Period, in batches, of the augmentation update.
        dataloader: Iterable of ``(imgs, labels, domain_labels)`` batches that
            also supports ``len()``.
        loader_val: Validation loader used for the periodic evaluation.
        swad_algorithm: Current averaged model being accumulated.
        swad: Tracker providing ``update_and_evaluate`` and ``dead_valley``.
        log_ten: Logger providing ``add_scalar`` and ``add_image``.
        metric: Training metric accumulator; a reset deep copy is used for
            validation.

    Returns:
        ``(step, flat_end, swad_algorithm, swad)``: the updated step counter,
        whether early stopping was triggered, the current averaged model and
        the tracker.
    """
    flat_end = False
    # Clamp to 1: with fewer than 25 total local steps the integer division
    # yields 0 and ``step % check_freq`` below would raise ZeroDivisionError.
    check_freq = max(1, len(dataloader) * args.local_epochs // 25)
    for i, data_list in enumerate(dataloader):
        imgs, labels, domain_labels = data_list
        imgs = imgs.cuda()
        labels = labels.cuda()
        context = labels

        if i % n_inner == 0:
            optimizer_aug.zero_grad()
            loss_aug, res = objective(imgs, labels, context, 'aug')
            loss_aug.backward()
            optimizer_aug.step()

        optimizer.zero_grad()
        with torch.no_grad():
            aug_x, _ = objective.trainable_aug(imgs, context)
            inputs = torch.stack([objective.base_aug(_x) for _x in aug_x])

        # calculate loss
        pred = objective.model(inputs)
        loss = F.cross_entropy(pred, labels)
        loss.backward()
        optimizer.step()

        log_ten.add_scalar(f'{site_name}_train/loss_adv', res['loss adv.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss_teacher', res['loss teacher'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss_color', res['color reg.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/acc_student', res['acc.'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/acc_teacher', res['acc. teacher'], epochs * len(dataloader) + i)
        log_ten.add_scalar(f'{site_name}_train/loss', loss.item(), epochs * len(dataloader) + i)
        metric.update(pred, labels)

        swad_algorithm.update_parameters(deepcopy(objective.model), step=step)

        if step % check_freq == 0:
            # Validate with a fresh metric so training statistics are kept.
            val_metric = deepcopy(metric)
            val_metric.init()
            results_dict = site_evaluation(step, site_name, args, objective.model, loader_val, None, log_ten,
                                           val_metric, 'val', False)
            # site_evaluation switches the model to eval mode; restore it.
            objective.model.train()
            swad.update_and_evaluate(swad_algorithm, results_dict["acc"], results_dict["loss"])

            if swad.dead_valley:
                print("SWAD valley is dead -> early stop !")
                flat_end = True
                break
            swad_algorithm = swa_utils.AveragedModel(objective.model)  # reset

        step += 1

    log_ten.add_scalar(f'{site_name}_train/acc', metric.results()['acc'], epochs)
    log_ten.add_image(f'{site_name}/before', imgs[0], epochs)
    log_ten.add_image(f'{site_name}/after', aug_x[0], epochs)
    # The scheduler is None when no LR policy is configured, matching the
    # guard used by epoch_site_train.
    if scheduler is not None:
        scheduler.step()

    return step, flat_end, swad_algorithm, swad


def site_train_swad(comm_rounds, site_name, args, objective, optimizer, optimizer_aug, scheduler, dataloader, loader_val, log_ten, metric):
    """Run the local SWAD training loop for one site.

    Calls ``epoch_site_train_swad`` up to ``args.local_epochs`` times,
    threading the step counter, the running averaged model and the
    ``LossValley`` tracker through each call, and stops early when an epoch
    reports ``flat_end``.

    Returns:
        The ``module`` attribute of the model returned by the tracker's
        ``get_final_model()``.
    """
    progress = tqdm(range(args.local_epochs))
    global_step = 0
    running_average = swa_utils.AveragedModel(objective.model)
    valley = swa_utils.LossValley(
        n_converge=args.n_converge,
        n_tolerance=args.n_tolerance,
        tolerance_ratio=0.3,
    )

    for local_epoch in progress:
        progress.set_description(f'{site_name}_train')
        epoch_index = comm_rounds * args.local_epochs + local_epoch
        global_step, stop_early, running_average, valley = epoch_site_train_swad(
            epoch_index, global_step, site_name, args, objective, optimizer, optimizer_aug,
            scheduler, args.n_inner, dataloader, loader_val, running_average, valley, log_ten, metric)
        if stop_early:
            break

    # Report the step range and count of the averaged model that was selected.
    final_average = valley.get_final_model()
    first_step = final_average.start_step
    last_step = final_average.end_step
    print(f" [{first_step}-{last_step}]  (N={final_average.n_averaged})")

    return final_average.module


def site_evaluation(epochs, site_name, args, model, dataloader, log_file, log_ten, metric, note='after_fed', log=True):
    """Evaluate ``model`` on ``dataloader`` and log loss and accuracy.

    The model is switched to eval mode and every batch is run under
    ``torch.no_grad()``. Images are moved to the GPU; labels are handed to
    ``metric.update`` exactly as the loader yields them. The ``loss`` and
    ``acc`` scalars are always written to ``log_ten``. The text line goes to
    ``log_file`` only when ``log`` is true, so ``log_file`` may be ``None``
    otherwise.

    When ``args.flat`` is set, the top five Hessian eigenvalues are also
    computed with ``compute_hessian_eigenthings`` on deep copies of the model
    and loader, and the largest value and the first-to-fifth ratio are logged.
    ``dataloader.dataset.flag`` is set to 1 for that computation and back to
    0 afterwards.

    Returns:
        The mapping returned by ``metric.results()``.
    """
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            imgs, labels, _ = batch
            predictions = model(imgs.cuda())
            metric.update(predictions, labels)
    results_dict = metric.results()

    # Scalars are written whether or not the text log is requested.
    tag_prefix = f'{note}_{site_name}'
    log_ten.add_scalar(f'{tag_prefix}/loss', results_dict['loss'], epochs)
    log_ten.add_scalar(f'{tag_prefix}/acc', results_dict['acc'], epochs)
    if log:
        log_file.info(f'{note} Round: {epochs:3d} | Epochs: {args.local_epochs*epochs:3d} | Domain: {site_name} | loss: {results_dict["loss"]:.4f} | Acc: {results_dict["acc"]*100:.2f}%')

    if not args.flat:
        return results_dict

    print('Begin eigen decomposition.')
    dataset = dataloader.dataset
    dataset.flag = 1
    model_copy = deepcopy(model)
    loader_copy = deepcopy(dataloader)
    eigenvals, _ = compute_hessian_eigenthings(
        model_copy, loader_copy, torch.nn.CrossEntropyLoss(), 5, max_possible_gpu_samples=4096)
    dataset.flag = 0
    log_ten.add_scalar(f'flatness/{tag_prefix}_max', eigenvals[0], epochs)
    log_ten.add_scalar(f'flatness/{tag_prefix}_ratio', eigenvals[0]/eigenvals[4], epochs)
    print(f'Top-5 eigenvalues: {eigenvals}')

    return results_dict

def site_evaluation_class_level(epochs, site_name, args, model, dataloader, log_file, log_ten, metric, note='after_fed'):
    """Evaluate ``model`` and log loss, accuracy and class-level accuracy.

    The model is switched to eval mode and every batch is run under
    ``torch.no_grad()``. Images are moved to the GPU; labels are handed to
    ``metric.update`` as yielded by the loader. Three scalars are written to
    ``log_ten`` and one summary line to ``log_file``.

    Returns:
        The mapping returned by ``metric.results()``; it is read with the
        keys ``'loss'``, ``'acc'`` and ``'class_level_acc'``.
    """
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            imgs, labels, _ = batch
            predictions = model(imgs.cuda())
            metric.update(predictions, labels)
    results_dict = metric.results()

    # Tag suffix -> key in the metric results.
    tag_prefix = f'{note}_{site_name}'
    for suffix, key in (('loss', 'loss'), ('acc', 'acc'), ('class_acc', 'class_level_acc')):
        log_ten.add_scalar(f'{tag_prefix}_{suffix}', results_dict[key], epochs)

    log_file.info(f'{note} Round: {epochs:3d} | Epochs: {args.local_epochs*epochs:3d} | Domain: {site_name} | loss: {results_dict["loss"]:.4f} | Acc: {results_dict["acc"]*100:.2f}% | C Acc: {results_dict["class_level_acc"]*100:.2f}%')

    return results_dict

def site_only_evaluation(model, dataloader, metric):
    """Evaluate ``model`` on ``dataloader`` without any logging.

    The model is switched to eval mode and every batch is run under
    ``torch.no_grad()``. Images are moved to the GPU; labels are handed to
    ``metric.update`` as yielded by the loader.

    Returns:
        The value returned by ``metric.results()``.
    """
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            imgs, labels, _ = batch
            predictions = model(imgs.cuda())
            metric.update(predictions, labels)
    return metric.results()

def GetFedModel(args, num_classes, is_train=True):
    """Build the global model plus one model/optimizer/scheduler per domain.

    Args:
        args: Namespace providing ``dataset``, ``num_classes``, ``lr``,
            ``local_epochs``, ``comm`` and ``lr_policy``.
        num_classes: Number of classes passed to ``GetNetwork`` for the
            per-domain models.
        is_train: Flag passed to ``GetNetwork`` for the per-domain models.

    Returns:
        ``(global_model, model_dict, optimizer_dict, scheduler_dict)``. The
        dictionaries are keyed by domain name. For ``lr_policy == 'none'``
        the scheduler entry is ``None``; for an unrecognised policy no entry
        is created.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported datasets.
    """
    # Resolve the domain list first so an unsupported dataset fails with a
    # clear message before any network is built or moved to the GPU.
    if args.dataset == 'pacs':
        domain_list = pacs_domain_list
    elif args.dataset == 'officehome':
        domain_list = officehome_domain_list
    elif args.dataset == 'terrainc':
        domain_list = terra_incognita_list
    elif args.dataset == 'digits':
        domain_list = digits_domain_list
    else:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of "
            "'pacs', 'officehome', 'terrainc', 'digits'.")

    # NOTE(review): the global model is built from ``args.num_classes`` and a
    # hard-coded ``True`` rather than the ``num_classes`` / ``is_train``
    # arguments used for the per-domain models -- confirm this is intended.
    global_model, _ = GetNetwork(args, args.num_classes, True)
    global_model = global_model.cuda()
    model_dict = {}
    optimizer_dict = {}
    scheduler_dict = {}

    total_epochs = args.local_epochs * args.comm
    # Decay factor for each exponential policy name.
    exp_gammas = {'exp95': 0.95, 'exp98': 0.98, 'exp99': 0.99}

    for domain_name in domain_list:
        model_dict[domain_name], _ = GetNetwork(args, num_classes, is_train)
        model_dict[domain_name] = model_dict[domain_name].cuda()
        optimizer = torch.optim.SGD(model_dict[domain_name].parameters(), lr=args.lr, momentum=0.9,
                                    weight_decay=5e-4)
        optimizer_dict[domain_name] = optimizer

        if args.lr_policy == 'step':
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=int(total_epochs * 0.8), gamma=0.1)
        elif args.lr_policy == 'mul_step':
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[int(total_epochs*0.3), int(total_epochs*0.8)], gamma=0.1)
        elif args.lr_policy in exp_gammas:
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=exp_gammas[args.lr_policy])
        elif args.lr_policy == 'cos':
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=total_epochs)
        elif args.lr_policy == 'none':
            scheduler_dict[domain_name] = None

    return global_model, model_dict, optimizer_dict, scheduler_dict

def SaveCheckPoint(args, model, epochs, path, optimizer=None, schedule=None, note='best_val'):
    """Serialise a training checkpoint to ``<path>/<note>.pt``.

    Args:
        args: Run configuration stored under the ``'args'`` key.
        model: Object whose ``state_dict()`` is stored under ``'model'``.
        epochs: Epoch/round counter stored under ``'epochs'``.
        path: Output directory; created (including parents) if missing.
        optimizer: Optional object whose ``state_dict()`` is stored under
            ``'optimizer'``.
        schedule: Optional object whose ``state_dict()`` is stored under
            ``'shceduler'``.
        note: File stem of the checkpoint, also stored under ``'note'``.
    """
    check_dict = {'args':args, 'epochs':epochs, 'model':model.state_dict(), 'note': note}
    if optimizer is not None:
        check_dict['optimizer'] = optimizer.state_dict()
    if schedule is not None:
        # The misspelt key is kept on purpose: renaming it would break
        # reading of checkpoints written by earlier runs.
        check_dict['shceduler'] = schedule.state_dict()

    # exist_ok avoids the check-then-create race when several writers target
    # the same directory at once.
    os.makedirs(path, exist_ok=True)

    torch.save(check_dict, os.path.join(path, note+'.pt'))
    

def SaveAug(aug, path, note=''):
    torch.save(aug.state_dict(), os.path.join(path, note+'.pt'))
