import os
from network.get_network import GetNetwork
import torch
from configs.default import *
import torch.nn.functional as F
from tqdm import tqdm
import random
import torch.distributions as distributions
import torch.nn as nn 

def Shuffle_Batch_Data(data_in):
    """Return ``data_in`` indexed by a random permutation of its positions.

    A list ``[0, 1, ..., len(data_in) - 1]`` is shuffled in place with the
    ``random`` module and then used as a single index into ``data_in``, so
    ``data_in`` must support ``len()`` and indexing with a list of integers.

    Args:
        data_in: The batch container to permute.

    Returns:
        Whatever ``data_in[<list of indices>]`` evaluates to.
    """
    permutation = [*range(len(data_in))]
    random.shuffle(permutation)
    shuffled = data_in[permutation]
    return shuffled

def epoch_site_train(epochs, site_name, model, optimzier, scheduler, dataloader, metric, args):
    """Run one local training epoch for a site, then step the LR scheduler.

    For every batch the model output is converted by ``featurize`` into a
    sampled representation ``z`` together with its Gaussian parameters
    ``(z_mu, z_sigma)``. ``z`` is classified by a linear head and the
    objective is::

        cross_entropy(logits, labels)
        + args.L2R_coeff * regL2R
        + args.CMI_coeff * regCMI

    where ``regL2R`` is the batch mean of the L2 norm of ``z`` and ``regCMI``
    is the closed-form KL divergence between ``N(z_mu * C, z_sigma * C)`` and
    a per-class reference Gaussian ``N(r_mu[y], softplus(r_sigma)[y])``,
    summed over the latent dimensions and averaged over the batch.

    NOTE(review): the linear head ``cls`` and the tensors ``r_mu``,
    ``r_sigma`` and ``C`` are created anew for every batch, after
    ``optimzier`` already exists, so ``optimzier.step()`` cannot update them:
    each batch sees a freshly initialised head and constant reference
    statistics (``r_mu = 0``, ``r_sigma = softplus(1)``, ``C = 1``). This is
    kept as-is because making them persistent needs a change that spans the
    evaluation code as well.

    Args:
        epochs: Global epoch index; not used inside this function.
        site_name: Site identifier; not used inside this function.
        model: Network whose output ``featurize`` splits into mean and scale.
        optimzier: Optimizer stepped once per batch (parameter name kept
            as spelled for caller compatibility).
        scheduler: LR scheduler stepped once after all batches.
        dataloader: Iterable yielding ``(imgs, labels, domain_labels)``.
        metric: Object whose ``update(logits, labels)`` is called per batch.
        args: Namespace providing ``z_dim``, ``nbc``, ``L2R_coeff`` and
            ``CMI_coeff``.
    """
    model.train()
    for imgs, labels, _domain_labels in dataloader:
        imgs = imgs.cuda()
        labels = labels.cuda()

        # The head is initialised on CPU and moved afterwards, and is created
        # before featurize(), so random-number consumption order is unchanged.
        cls = nn.Linear(args.z_dim, args.nbc)
        cls.cuda()

        z, (z_mu, z_sigma) = featurize(imgs, True, model, args)
        logits = cls(z)
        loss = F.cross_entropy(logits, labels)
        regL2R = z.norm(dim=1).mean()

        # Reference statistics are allocated directly on the device where
        # they are consumed instead of being copied from CPU every batch.
        r_mu = nn.Parameter(torch.zeros(args.nbc, args.z_dim, device=labels.device))
        r_sigma = nn.Parameter(torch.ones(args.nbc, args.z_dim, device=labels.device))
        C = nn.Parameter(torch.ones([], device=z_mu.device))

        # Select the reference Gaussian that belongs to each sample's class.
        class_mu = r_mu[labels]
        class_sigma = F.softplus(r_sigma)[labels]
        z_mu_scaled = z_mu * C
        z_sigma_scaled = z_sigma * C

        # KL( N(z_mu_scaled, z_sigma_scaled) || N(class_mu, class_sigma) ), per dimension.
        regCMI = torch.log(class_sigma) - torch.log(z_sigma_scaled) + \
            (z_sigma_scaled**2 + (z_mu_scaled - class_mu)**2) / (2 * class_sigma**2) - 0.5
        regCMI = regCMI.sum(1).mean()

        obj = loss + args.L2R_coeff * regL2R + args.CMI_coeff * regCMI

        optimzier.zero_grad()
        obj.backward()
        optimzier.step()

        metric.update(logits, labels)

    scheduler.step()
    
def site_train(comm_rounds, site_name, args, model, optimizer, scheduler, dataloader, metric):
    """Train one site for ``args.local_epochs`` epochs with a progress bar.

    Each iteration calls ``epoch_site_train`` with the global epoch index
    ``comm_rounds * args.local_epochs + <local epoch>``.

    Args:
        comm_rounds: Index of the current communication round.
        site_name: Site identifier, shown in the progress-bar description.
        args: Namespace providing ``local_epochs`` (also forwarded).
        model, optimizer, scheduler, dataloader, metric: Forwarded unchanged
            to ``epoch_site_train``.
    """
    first_epoch = comm_rounds * args.local_epochs
    progress = tqdm(range(args.local_epochs))
    for offset in progress:
        progress.set_description(f'{site_name}_train')
        epoch_site_train(
            first_epoch + offset,
            site_name,
            model,
            optimizer,
            scheduler,
            dataloader,
            metric,
            args,
        )

def site_evaluation(epochs, site_name, args, model, dataloader, log_file, metric, note='after_fed'):
    """Evaluate ``model`` on one site and write a summary line to ``log_file``.

    The model is put in eval mode and every batch is processed under
    ``torch.no_grad()``: a representation is sampled with ``featurize``,
    passed through a linear head, and the logits are fed to
    ``metric.update``. After the loop ``metric.results()`` is logged.

    NOTE(review): the linear head is constructed inside the loop, so each
    batch is scored with a newly initialised ``nn.Linear`` - confirm whether
    a trained head was meant to be used here.

    Args:
        epochs: Round index used in the log line.
        site_name: Domain name used in the log line.
        args: Namespace providing ``z_dim``, ``nbc`` and ``local_epochs``.
        model: Network passed to ``featurize``.
        dataloader: Iterable yielding ``(imgs, labels, domain_labels)``.
        log_file: Logger-like object with an ``info`` method.
        metric: Object with ``update(logits, labels)`` and ``results()``;
            the result must contain ``'loss'`` and ``'acc'``.
        note: Prefix of the log line.

    Returns:
        The dictionary returned by ``metric.results()``.
    """
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            imgs, labels, _domain_labels = batch
            imgs, labels = imgs.cuda(), labels.cuda()
            # Head first, then featurize: keeps random-number usage in the same order.
            head = nn.Linear(args.z_dim, args.nbc).cuda()
            z, (_z_mu, _z_sigma) = featurize(imgs, True, model, args)
            metric.update(head(z), labels)

    results_dict = metric.results()
    message = f'{note} Round: {epochs:3d} | Epochs: {args.local_epochs*epochs:3d} | Domain: {site_name} | loss: {results_dict["loss"]:.4f} | Acc: {results_dict["acc"]*100:.2f}%'
    log_file.info(message)
    return results_dict

def site_evaluation_class_level(epochs, site_name, args, model, dataloader, log_file, log_ten, metric, note='after_fed'):
    """Evaluate ``model`` on one site and log loss, accuracy and class-level accuracy.

    The model is switched to eval mode; under ``torch.no_grad()`` each batch
    of images is moved with ``.cuda()``, run through ``model`` and handed to
    ``metric.update`` together with the labels as yielded by the loader.
    The three values from ``metric.results()`` are then written to
    ``log_ten`` (one ``add_scalar`` call each, step = ``epochs``) and a
    summary line is written to ``log_file``.

    Args:
        epochs: Round index, used as the scalar step and in the log line.
        site_name: Domain name used in tags and in the log line.
        args: Namespace providing ``local_epochs``.
        model: Callable network with an ``eval`` method.
        dataloader: Iterable yielding ``(imgs, labels, domain_labels)``.
        log_file: Logger-like object with an ``info`` method.
        log_ten: Writer-like object with ``add_scalar(tag, value, step)``.
        metric: Object with ``update`` and ``results``; the result must
            contain ``'loss'``, ``'acc'`` and ``'class_level_acc'``.
        note: Prefix for the tags and the log line.

    Returns:
        The dictionary returned by ``metric.results()``.
    """
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            imgs, labels, _domain_labels = batch
            metric.update(model(imgs.cuda()), labels)

    results_dict = metric.results()

    # Values are looked up one at a time so scalars are emitted in order.
    scalar_plan = (
        (f'{note}_{site_name}_loss', 'loss'),
        (f'{note}_{site_name}_acc', 'acc'),
        (f'{note}_{site_name}_class_acc', 'class_level_acc'),
    )
    for tag, key in scalar_plan:
        log_ten.add_scalar(tag, results_dict[key], epochs)

    message = f'{note} Round: {epochs:3d} | Epochs: {args.local_epochs*epochs:3d} | Domain: {site_name} | loss: {results_dict["loss"]:.4f} | Acc: {results_dict["acc"]*100:.2f}% | C Acc: {results_dict["class_level_acc"]*100:.2f}%'
    log_file.info(message)
    return results_dict

def site_only_evaluation(model, dataloader, metric):
    """Evaluate ``model`` on ``dataloader`` without any logging.

    The model is switched to eval mode; under ``torch.no_grad()`` each batch
    of images is moved with ``.cuda()``, run through ``model`` and passed to
    ``metric.update`` together with the labels as yielded by the loader.

    Args:
        model: Callable network with an ``eval`` method.
        dataloader: Iterable yielding ``(imgs, labels, domain_labels)``.
        metric: Object with ``update(output, labels)`` and ``results()``.

    Returns:
        The value returned by ``metric.results()``.
    """
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            imgs, labels, _domain_labels = batch
            predictions = model(imgs.cuda())
            metric.update(predictions, labels)
    return metric.results()

def GetFedModel(args, num_classes, is_train=True):
    """Build the global model plus one model, optimizer and scheduler per domain.

    Every network is created with ``GetNetwork(args, num_classes, is_train)``
    and moved to the GPU. Each per-domain model gets an SGD optimizer
    (``lr=args.lr``, momentum 0.9, weight decay 5e-4) and, when
    ``args.lr_policy`` is one of ``'step'``, ``'mul_step'``, ``'exp95'``,
    ``'exp98'``, ``'exp99'`` or ``'cos'``, a matching LR scheduler. For any
    other policy no scheduler entry is created for the domain.

    Args:
        args: Namespace providing ``dataset``, ``lr``, ``lr_policy``,
            ``local_epochs`` and ``comm`` (also forwarded to ``GetNetwork``).
        num_classes: Number of classes, used for the global model and for
            every per-domain model.
        is_train: Forwarded to ``GetNetwork``.

    Returns:
        ``(global_model, model_dict, optimizer_dict, scheduler_dict)`` where
        the three dictionaries are keyed by domain name.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # Use the same class count as the per-domain models below so that the
    # global and local networks always agree in shape.
    global_model, _feature_level = GetNetwork(args, num_classes, is_train)
    global_model = global_model.cuda()
    model_dict = {}
    optimizer_dict = {}
    scheduler_dict = {}

    if args.dataset == 'vlcs':
        domain_list = vlcs_domain_list
    elif args.dataset == 'pacs':
        domain_list = pacs_domain_list
    elif args.dataset == 'officehome':
        domain_list = officehome_domain_list
    elif args.dataset == 'domainNet':
        domain_list = domainNet_domain_list
    elif args.dataset == 'terrainc':
        domain_list = terra_incognita_list
    else:
        raise ValueError(f'Unsupported dataset: {args.dataset!r}')

    # Schedulers are stepped once per local epoch over the whole run.
    total_epochs = args.local_epochs * args.comm
    lr_policy = args.lr_policy

    for domain_name in domain_list:
        site_model, _ = GetNetwork(args, num_classes, is_train)
        site_model = site_model.cuda()
        site_optimizer = torch.optim.SGD(site_model.parameters(), lr=args.lr, momentum=0.9,
                                         weight_decay=5e-4)
        model_dict[domain_name] = site_model
        optimizer_dict[domain_name] = site_optimizer

        if lr_policy == 'step':
            # step_size must be >= 1: StepLR takes the epoch modulo
            # step_size, which fails for 0 on very short runs.
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.StepLR(
                site_optimizer, step_size=max(1, int(total_epochs * 0.8)), gamma=0.1)
        elif lr_policy == 'mul_step':
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.MultiStepLR(
                site_optimizer, milestones=[int(total_epochs * 0.3), int(total_epochs * 0.8)], gamma=0.1)
        elif lr_policy == 'exp95':
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.ExponentialLR(site_optimizer, gamma=0.95)
        elif lr_policy == 'exp98':
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.ExponentialLR(site_optimizer, gamma=0.98)
        elif lr_policy == 'exp99':
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.ExponentialLR(site_optimizer, gamma=0.99)
        elif lr_policy == 'cos':
            scheduler_dict[domain_name] = torch.optim.lr_scheduler.CosineAnnealingLR(site_optimizer, T_max=total_epochs)
        # Any other policy intentionally leaves this domain without a scheduler.

    return global_model, model_dict, optimizer_dict, scheduler_dict

def SaveCheckPoint(args, model, epochs, path, optimizer=None, schedule=None, note='best_val'):
    """Save a training checkpoint to ``<path>/<note>.pt`` with ``torch.save``.

    The saved dictionary always contains ``'args'``, ``'epochs'``,
    ``'model'`` (the model's ``state_dict()``) and ``'note'``. When given,
    the optimizer state is stored under ``'optimizer'`` and the scheduler
    state under ``'scheduler'``. The scheduler state is additionally stored
    under the historical misspelled key ``'shceduler'`` so that code reading
    checkpoints by the old key keeps working.

    Args:
        args: Run configuration, stored as-is.
        model: Module whose ``state_dict()`` is saved.
        epochs: Epoch or round counter, stored as-is.
        path: Target directory; created (including parents) if missing.
        optimizer: Optional optimizer whose ``state_dict()`` is saved.
        schedule: Optional LR scheduler whose ``state_dict()`` is saved.
        note: Checkpoint label; also the file stem.
    """
    check_dict = {'args': args, 'epochs': epochs, 'model': model.state_dict(), 'note': note}
    if optimizer is not None:
        check_dict['optimizer'] = optimizer.state_dict()
    if schedule is not None:
        scheduler_state = schedule.state_dict()
        check_dict['scheduler'] = scheduler_state
        # Legacy key (typo) kept for readers of older checkpoints.
        check_dict['shceduler'] = scheduler_state

    # exist_ok avoids the check-then-create race when several processes save
    # into the same directory.
    os.makedirs(path, exist_ok=True)

    torch.save(check_dict, os.path.join(path, note + '.pt'))
    

def featurize(x, return_dist, model, args):
    """Sample a latent representation from the Gaussian predicted by ``model``.

    ``model(x)`` is split along dimension 1: the first ``args.z_dim`` columns
    are used as the mean and the remaining columns, passed through softplus,
    as the standard deviation of a diagonal Normal. One reparameterised
    sample is drawn and reshaped to ``(-1, args.z_dim)``.

    Args:
        x: Input passed to ``model``.
        return_dist: When truthy, also return the distribution parameters.
        model: Callable producing a 2-D tensor of mean and raw scale values.
        args: Namespace providing ``z_dim``.

    Returns:
        ``z`` if ``return_dist`` is falsy, otherwise
        ``(z, (z_mu, z_sigma))``.
    """
    latent_dim = args.z_dim
    raw_params = model(x)

    z_mu = raw_params[:, :latent_dim]
    # Softplus keeps the scale strictly positive.
    z_sigma = F.softplus(raw_params[:, latent_dim:])

    posterior = distributions.Independent(
        distributions.normal.Normal(z_mu, z_sigma), 1
    )
    # A single sample is drawn; the leading sample axis is folded away.
    z = posterior.rsample([1]).view([-1, latent_dim])

    if not return_dist:
        return z
    return z, (z_mu, z_sigma)