import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.distributions as dist

import copy
import numpy as np
from collections import defaultdict, OrderedDict

from methods.base import *
from util import *

from src.utils.torch_utils import get_flat_grad, get_state_dict, get_flat_params_from, set_flat_params_to

from model_manual import *


class Model(Base):
    """Client model whose local training can penalise its own loss gradient.

    ``train_client`` optimises cross-entropy on ``self.net``. When a reference
    vector ``RegPole`` is supplied, it also adds penalties on the flat loss
    gradient ``g``. ``g`` is computed with ``create_graph=True`` so that the
    penalties are themselves differentiable:

    * ``J_ind_coef  * ||RegPole - g||^2``
    * ``J_norm_coef * ||g||^2``

    Optionally, a second network (``self.base_net``) is trained side by side
    on the same batches with plain cross-entropy.

    NOTE(review): ``self.net``, ``self.optim``, ``self.base_net`` and
    ``self.base_optim`` are not created in this class. Presumably
    ``Base.__init__`` builds them — confirm against ``methods.base``.
    """

    def __init__(self, args, train_base):
        super(Model, self).__init__(args, train_base)
        # Coefficient of the ||g||^2 penalty; 0 disables it.
        self.J_norm_coef = args.J_norm_coef
        # Coefficient of the ||RegPole - g||^2 penalty; 0 disables it.
        self.J_ind_coef = args.J_ind_coef

    def train_client(self, loader, steps, RegPole, base_train):
        """Run ``steps`` local optimisation steps and report statistics.

        Args:
            loader: Iterable yielding ``(x, y)`` batches. It is cycled
                (re-iterated when exhausted), so ``steps`` may exceed its
                length.
            steps: Number of optimisation steps. Must be >= 1, because the
                last training batch is reused to compute the returned gradient.
            RegPole: Flat reference vector for the gradient penalties, or
                ``None`` to train on plain cross-entropy.
                NOTE(review): assumed to match the shape/device of
                ``get_flat_grad`` over ``self.net.parameters()`` — confirm.
            base_train: If true, also take one step of ``self.base_net`` with
                ``self.base_optim`` on every batch.

        Returns:
            A dict with:

            * ``'acc'`` and ``'loss'``: meter averages.
            * ``'J_JNB'``: detached flat loss gradient of ``self.net`` on the
              last batch, evaluated after the final update.
            * ``'regJ_ind'`` / ``'regJ_norm'``: included when ``RegPole`` is
              given.
            * ``'norm_J_base'``: included when ``base_train`` is true.

        Raises:
            ValueError: If ``steps`` < 1.
        """
        if steps < 1:
            # Without a training step there is no batch to evaluate the
            # returned gradient on.
            raise ValueError(f"steps must be >= 1, got {steps}")

        self.train()
        lossMeter = AverageMeter()
        accMeter = AverageMeter()
        regJ_ind_Meter = AverageMeter()
        regJ_norm_Meter = AverageMeter()

        # Keep one iterator for the whole call and restart it only when it is
        # exhausted. Calling ``next(iter(loader))`` every step restarts the
        # loader each time: an unshuffled loader then always yields its first
        # batch, and a multi-worker DataLoader re-spawns workers on every step.
        batches = iter(loader)

        for _ in range(steps):
            try:
                x, y = next(batches)
            except StopIteration:
                batches = iter(loader)
                x, y = next(batches)
            # as_tensor avoids the copy + warning that torch.tensor() emits
            # when ``y`` is already a tensor.
            x, y = x.cuda(), torch.as_tensor(y).cuda()

            logits = self.net(x)
            loss = F.cross_entropy(logits, y)
            obj = loss

            if RegPole is not None:
                regJ_ind = torch.zeros_like(obj)
                regJ_norm = torch.zeros_like(obj)

                if self.J_ind_coef or self.J_norm_coef:
                    self.optim.zero_grad()
                    # create_graph=True keeps the gradient differentiable, so
                    # the penalties below back-propagate into the weights.
                    latest_model_local_grad = get_flat_grad(
                        loss, self.net.parameters(), create_graph=True)

                if self.J_ind_coef:
                    # Pull the local gradient towards the reference RegPole.
                    regJ_ind = torch.norm(RegPole - latest_model_local_grad) ** 2
                    obj = obj + self.J_ind_coef * regJ_ind

                if self.J_norm_coef:
                    # Penalise the squared norm of the local gradient, weighted
                    # by its own coefficient and logged as regJ_norm.
                    # NOTE(review): the ||g||^2 form is inferred from the name
                    # ``J_norm`` — confirm this is the intended penalty.
                    regJ_norm = torch.norm(latest_model_local_grad) ** 2
                    obj = obj + self.J_norm_coef * regJ_norm

            self.optim.zero_grad()
            obj.backward()
            self.optim.step()

            # Metrics use the logits computed before this step's update.
            acc = (logits.argmax(1) == y).float().mean()
            lossMeter.update(loss.detach(), x.shape[0])
            accMeter.update(acc.detach(), x.shape[0])

            if RegPole is not None:
                regJ_ind_Meter.update(regJ_ind.detach(), x.shape[0])
                regJ_norm_Meter.update(regJ_norm.detach(), x.shape[0])

            if base_train:
                base_logits = self.base_net(x)
                base_loss = F.cross_entropy(base_logits, y)
                self.base_optim.zero_grad()
                base_loss.backward()
                self.base_optim.step()

        # Determine the gradients to send: evaluated on the last training batch
        # at the updated parameters. Only their values are returned, so no
        # double-backward graph is needed (create_graph=False).
        logits = self.net(x)
        loss = F.cross_entropy(logits, y)
        self.optim.zero_grad()
        latest_model_local_grad = get_flat_grad(
            loss, self.net.parameters(), create_graph=False)

        if base_train:
            base_logits = self.base_net(x)
            base_loss = F.cross_entropy(base_logits, y)
            self.base_optim.zero_grad()
            J_base = get_flat_grad(
                base_loss, self.base_net.parameters(), create_graph=False)

        stats = {'acc': accMeter.average(), 'loss': lossMeter.average(),
                 'J_JNB': latest_model_local_grad.detach()}

        if RegPole is not None:
            stats.update({'regJ_ind': regJ_ind_Meter.average(),
                          'regJ_norm': regJ_norm_Meter.average()})
        if base_train:
            stats.update({'norm_J_base': torch.norm(J_base.detach())})
        return stats