import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import tqdm.auto as auto


class fedtrainer:
    def __init__(
        self, 
        mu=1e-4, 
        eta=1e-3,
        f=1,
        gamma=0.9995,
        momentum=0.9,
        weight_decay=0.01,
        slow_weight=0.01,
        k=1,
        model=None, 
        dltrain=None,
        dlvalid=None,
        loss_func=nn.CrossEntropyLoss(),
        classifier_only=True,
        optimizer_name=None,
        onebit=False,
        device='cpu',
    ):
        super().__init__()
        
        self.mu = mu
        self.eta = eta
        self.gamma = gamma
        self.momentum = momentum
        self.onebit = onebit
        self.weight_decay = weight_decay
        self.slow_weight = slow_weight
        self.k = k
        self.f = f
        self.model = model.to(device)
        self.dltrain = dltrain
        self.dlvalid = dlvalid
        self.loss_func = loss_func
        self.optimizer_name = optimizer_name
        self.len_dltrain = len(self.dltrain[0])
        self.len_dlvalid = len(self.dlvalid)
        
        self.n_normal = 3
        self.n_slow = 2
        if classifier_only:
            to_be_optim = [v for k, v in self.model.named_parameters() if k.__contains__('classifier')]
            self.optim_keys = [k[0] for k in self.model.named_parameters() if k.__contains__('classifier')]
        else:
            to_be_optim = [v for k, v in self.model.named_parameters() if not k.__contains__('embed')]
            self.optim_keys = [k[0] for k in self.model.named_parameters() if not k.__contains__('embed')]
            
        print('The following parameters will be optimized:')
        # print(self.optim_keys)
        for k in self.optim_keys:
            print(k)
            
        if optimizer_name is None:
            print('Using no optimizer')
        elif optimizer_name == 'SGD':
            print('Using SGD optimizer')
            self.optimizer = optim.SGD(
                to_be_optim,
                lr=self.eta,
                weight_decay=self.weight_decay,
                momentum=self.momentum,
            )
        elif optimizer_name == 'Adam':
            print('Using Adam optimizer')
            self.optimizer = optim.Adam(
                to_be_optim,
                lr=self.eta,
                weight_decay=self.weight_decay,
            )
        elif optimizer_name == 'AdamW':
            print('Using AdamW optimizer')
            self.optimizer = optim.AdamW(
                to_be_optim,
                lr=self.eta,
                weight_decay=weight_decay,
            )
        else:
            raise NotImplementedError
        
        
        if isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.device = device
        
            
    def seed_perturb(self, seed, scale, mask=None):
        torch.manual_seed(seed)
        
        for k, v in self.model.named_parameters():
            if k in self.optim_keys:
                dv = torch.randn_like(v).to(v.device)
                if mask is not None:
                    v.data += dv * self.mu * scale * mask[k]
                else:
                    v.data += dv * self.mu * scale
                    
                    
    def seed_grad_onebit(self, seedlist, directionlist, mask=None):
        assert self.optimizer_name is not None, 'Must use an optimizer, non-optimizer version not implemented'
        self.optimizer.zero_grad()
        for i in range(self.k):
            torch.manual_seed(seedlist[i])
            
            for k, v in self.model.named_parameters():
                if k in self.optim_keys:
                    dv = torch.randn_like(v).to(v.device)
                    thisdirection = sum(directionlist[i :: self.k])
                    thisdirection = 1 if thisdirection > 0 else -1
                    if mask is not None:
                        if v.grad is None:
                            v.grad = dv * thisdirection * mask[k] / self.k * self.eta
                        else:
                            v.grad += dv * thisdirection * mask[k] / self.k * self.eta
                    else:
                        if v.grad is None:
                            v.grad = dv * thisdirection / self.k * self.eta
                        else:
                            v.grad += dv * thisdirection / self.k * self.eta
        self.optimizer.step()
                    
                
                
    def seed_grad(self, seedlist, directionlist, mask=None):
        l = len(seedlist)
        
        if self.optimizer_name is not None:
            self.optimizer.zero_grad()
            for i in range(l):
                torch.manual_seed(seedlist[i])
                
                for k, v in self.model.named_parameters():
                    if k in self.optim_keys:
                        dv = torch.randn_like(v).to(v.device)
                        if mask is not None:
                            if v.grad is None:
                                v.grad = dv * directionlist[i] * mask[k] / l * self.eta
                            else:
                                v.grad += dv * directionlist[i] * mask[k] / l * self.eta
                        else:
                            if v.grad is None:
                                v.grad = dv * directionlist[i] / l * self.eta
                            else:
                                v.grad += dv * directionlist[i] / l * self.eta
            self.optimizer.step()
            
        else:
            for i in range(l):
                torch.manual_seed(seedlist[i])
                
                for k, v in self.model.named_parameters():
                    if k in self.optim_keys:
                        dv = torch.randn_like(v).to(v.device)
                        if 'bias' not in k and 'layer_norm' not in k and 'layernorm' not in k:
                            v.data -= v.data * self.weight_decay * self.eta
                        if mask is not None:
                            v.data -= dv * directionlist[i] * mask[k] / l * self.eta
                        else:
                            v.data -= dv * directionlist[i] / l * self.eta
                        
    
    def logger_init(self):
        """Reset the in-memory log columns and the running counters.

        Creates a fresh empty list for every log column, sets the seed
        counter to -1 (the epoch methods increment it before use, so the
        first seed is 0) and zeroes the step / time counters.
        """
        column_attrs = (
            'losslist',
            'acclist',
            'modelist',
            'trainlist',
            'epochlist',
            'lrlist',
            'elapsed_steplist',
            'elapsed_timelist',
            'pgradlist',
        )
        for attr in column_attrs:
            setattr(self, attr, [])

        self.seed = -1
        self.elapsed_step = 0
        self.elapsed_time = 0
        
    
    def logger_log(self, loss, acc, mode, train, epoch, lr, pgrad):
        """Append one record to the log columns.

        Besides the arguments, the current ``self.elapsed_step`` and
        ``self.elapsed_time`` counters are recorded.

        Args:
            loss: Loss value to record.
            acc: Accuracy value to record.
            mode: Label stored in the ``mode`` column (callers here pass ``'zo'``).
            train: Label stored in the ``train`` column (``'yes'`` / ``'no'``
                in the callers visible here).
            epoch: Epoch index.
            lr: Learning rate to record.
            pgrad: Projected-gradient / direction value to record.
        """
        record = (
            (self.losslist, loss),
            (self.acclist, acc),
            (self.modelist, mode),
            (self.trainlist, train),
            (self.epochlist, epoch),
            (self.elapsed_steplist, self.elapsed_step),
            (self.elapsed_timelist, self.elapsed_time),
            (self.pgradlist, pgrad),
            (self.lrlist, lr),
        )
        for column, value in record:
            column.append(value)
        
        
    def logger_summary(self):
        """Return everything logged so far as a ``pandas.DataFrame``.

        Returns:
            A frame with one row per ``logger_log`` call and the columns
            ``loss``, ``acc``, ``mode``, ``train``, ``epoch``, ``lr``,
            ``time``, ``step`` and ``pgrad``, in that order.
        """
        columns = (
            ('loss', self.losslist),
            ('acc', self.acclist),
            ('mode', self.modelist),
            ('train', self.trainlist),
            ('epoch', self.epochlist),
            ('lr', self.lrlist),
            ('time', self.elapsed_timelist),
            ('step', self.elapsed_steplist),
            ('pgrad', self.pgradlist),
        )
        return pd.DataFrame(dict(columns))
    
    
    def get_metric(self, x):
        fx = self.model(x['pixel_values'].to(self.device))
        y = x['labels'].to(self.device)
        acc = (torch.argmax(fx.logits, -1) == y).sum() / len(y)
        loss = self.loss_func(fx.logits, y)
        return loss, acc
    
    
    def get_direction(self, _loss, loss_):
        return (loss_.item() - _loss.item()) / self.mu / 2
    
    
    def get_mask(self):
        self.mask = {}
        for k, v in self.model.named_parameters():
            self.mask[k] = v.abs() > torch.quantile(v.abs(), self.quant)
            

    @torch.no_grad()
    def zo_epoch_train_proposed(self, epoch):
        """Run one zeroth-order training epoch with slow and normal loaders.

        Each outer step does the following:

        1. Every slow loader (indices ``0 .. n_slow - 1`` of ``self.dltrain``)
           is probed ``k`` times at the current parameters; the resulting
           ``(seed, direction)`` pairs are held back.
        2. For each of ``f`` sub-rounds, every normal loader (indices
           ``n_slow .. n_slow + n_normal - 1``) is probed ``k`` times and the
           pairs from *that sub-round* are applied with ``seed_grad``.
        3. The held-back slow directions are scaled by ``self.slow_weight``
           and applied with ``seed_grad``.

        A probe perturbs the parameters by ``-mu`` and ``+mu`` along the
        seeded direction, measures the loss on the same batch at both
        points, restores the parameters, and logs the first measurement.

        Fix over the previous version: the normal seed / direction lists were
        created once per outer step, so with ``f > 1`` every sub-round
        re-applied the earlier sub-rounds' seeds and divided by the growing
        list length. They are now reset per sub-round. Behaviour at the
        default ``f == 1`` is unchanged.

        Args:
            epoch: Epoch index; used only for the progress text and the log.

        Raises:
            StopIteration: If a loader runs out of batches before the epoch
                ends (the epoch length comes from ``len(self.dltrain[0])``).
        """
        n_rounds = self.len_dltrain // (self.f * self.k)
        pbar = auto.tqdm(enumerate(range(n_rounds)), total=n_rounds)
        enum = [enumerate(_) for _ in self.dltrain]

        def probe(loader_idx):
            # One two-sided loss probe on the next batch of the given loader.
            self.seed += 1
            _, x = next(enum[loader_idx])

            self.seed_perturb(self.seed, -1)
            _loss, _acc = self.get_metric(x)
            self.seed_perturb(self.seed, 2)
            loss_, acc_ = self.get_metric(x)
            # Undo the net +mu shift so the parameters are back where they started.
            self.seed_perturb(self.seed, -1)
            direction = self.get_direction(_loss, loss_)

            pbar.desc = 'train, epoch, %2d, loss, %2.4f, acc, %2.4f' % (epoch, loss_.item(), acc_.item())
            self.elapsed_step += 1
            self.logger_log(_loss.item(), _acc.item(), 'zo', 'yes', epoch, self.eta, direction)
            return self.seed, direction

        for _ in pbar:
            self.elapsed_time += 1

            # Slow loaders: probed first, applied last.
            slow_seedlist = []
            slow_directionlist = []
            for i_slow in range(self.n_slow):
                for _q in range(self.k):
                    seed, direction = probe(i_slow)
                    slow_seedlist.append(seed)
                    slow_directionlist.append(direction)

            for _f in range(self.f):
                # Fresh lists per sub-round so each update uses only the
                # directions measured at the current parameters.
                normal_seedlist = []
                normal_directionlist = []
                for i_normal in range(self.n_normal):
                    for _q in range(self.k):
                        seed, direction = probe(i_normal + self.n_slow)
                        normal_seedlist.append(seed)
                        normal_directionlist.append(direction)

                self.seed_grad(normal_seedlist, normal_directionlist)

            slow_directionlist = [_ * self.slow_weight for _ in slow_directionlist]
            self.seed_grad(slow_seedlist, slow_directionlist)
            
            
    @torch.no_grad()
    def zo_epoch_train_binary(self, epoch):
        """Run one zeroth-order epoch in which each probe reports only a sign.

        Per outer step, each of the ``n_slow + n_normal`` loaders is probed
        ``k`` times. The seed counter is rewound by ``k`` after each loader,
        so all loaders share the same ``k`` seeds within a step; it is then
        advanced by ``k`` once. Each probe's direction is reduced to +1
        (strictly positive) or -1. The collected pairs are applied with
        ``seed_grad_onebit`` when ``self.onebit`` is true, else ``seed_grad``.

        Args:
            epoch: Epoch index; used only for the progress text and the log.

        Raises:
            AssertionError: If no optimizer was configured.
            StopIteration: If a loader runs out of batches before the epoch ends.
        """
        assert self.optimizer_name is not None, 'Must use optimizer to do binary training'
        n_rounds = self.len_dltrain // self.k
        pbar = auto.tqdm(enumerate(range(n_rounds)), total=n_rounds)
        streams = [enumerate(loader) for loader in self.dltrain]

        def probe(batch):
            # Two-sided loss measurement along the direction of the current seed.
            self.seed_perturb(self.seed, -1)
            loss_minus, acc_minus = self.get_metric(batch)
            self.seed_perturb(self.seed, 2)
            loss_plus, acc_plus = self.get_metric(batch)
            self.seed_perturb(self.seed, -1)
            return loss_minus, acc_minus, loss_plus, acc_plus

        for _ in pbar:
            self.elapsed_time += 1
            seedlist = []
            directionlist = []

            # Slow loaders occupy indices [0, n_slow) and normal loaders the
            # following n_normal indices; both are handled identically here.
            for loader_idx in range(self.n_slow + self.n_normal):
                for _q in range(self.k):
                    _, batch = next(streams[loader_idx])
                    self.seed += 1
                    loss_minus, acc_minus, loss_plus, acc_plus = probe(batch)

                    raw = self.get_direction(loss_minus, loss_plus)
                    sign = 1 if raw > 0 else -1

                    pbar.desc = 'train, epoch, %2d, loss, %2.4f, acc, %2.4f' % (epoch, loss_plus.item(), acc_plus.item())
                    self.elapsed_step += 1
                    self.logger_log(loss_minus.item(), acc_minus.item(), 'zo', 'yes', epoch, self.eta, sign)
                    seedlist.append(self.seed)
                    directionlist.append(sign)
                # Rewind so the next loader reuses the same k seeds.
                self.seed -= self.k

            # Consume the k shared seeds once for the whole step.
            self.seed += self.k
            if self.onebit:
                self.seed_grad_onebit(seedlist, directionlist)
            else:
                self.seed_grad(seedlist, directionlist)
            
    
    @torch.no_grad()
    def zo_epoch_train_baseline(self, epoch):
        """Run one zeroth-order epoch treating every loader the same way.

        Per outer step, each of the ``n_slow + n_normal`` loaders is probed
        ``k`` times, each probe with its own fresh seed. All
        ``(seed, direction)`` pairs of the step are then applied together by
        a single ``seed_grad`` call.

        Args:
            epoch: Epoch index; used only for the progress text and the log.

        Raises:
            StopIteration: If a loader runs out of batches before the epoch ends.
        """
        n_rounds = self.len_dltrain // self.k
        pbar = auto.tqdm(enumerate(range(n_rounds)), total=n_rounds)
        streams = [enumerate(loader) for loader in self.dltrain]

        def probe(batch):
            # Two-sided loss measurement along the direction of the current seed.
            self.seed_perturb(self.seed, -1)
            loss_minus, acc_minus = self.get_metric(batch)
            self.seed_perturb(self.seed, 2)
            loss_plus, acc_plus = self.get_metric(batch)
            self.seed_perturb(self.seed, -1)
            return loss_minus, acc_minus, loss_plus, acc_plus

        for _ in pbar:
            self.elapsed_time += 1
            seedlist = []
            directionlist = []

            # Slow loaders are indices [0, n_slow); normal loaders follow.
            for loader_idx in range(self.n_slow + self.n_normal):
                for _q in range(self.k):
                    self.seed += 1
                    _, batch = next(streams[loader_idx])

                    loss_minus, acc_minus, loss_plus, acc_plus = probe(batch)
                    direction = self.get_direction(loss_minus, loss_plus)

                    pbar.desc = 'train, epoch, %2d, loss, %2.4f, acc, %2.4f' % (epoch, loss_plus.item(), acc_plus.item())
                    self.elapsed_step += 1
                    self.logger_log(loss_minus.item(), acc_minus.item(), 'zo', 'yes', epoch, self.eta, direction)
                    seedlist.append(self.seed)
                    directionlist.append(direction)

            self.seed_grad(seedlist, directionlist)
            
            
    def epoch_valid(self, epoch, code='zo'):
        """Evaluate the model on every validation batch and log the results.

        For each batch of ``self.dlvalid`` the loss and accuracy from
        ``get_metric`` are shown in the progress bar and appended to the log
        with ``train='no'`` and a ``pgrad`` of 0.

        The loop runs under ``torch.no_grad()``: only ``.item()`` values are
        consumed, so recording an autograd graph for each forward pass would
        be wasted work, and the training epochs already run the same way.
        The model's train/eval mode is left untouched.

        Args:
            epoch: Epoch index; used for the progress text and the log.
            code: Label stored in the log's ``mode`` column.
        """
        pbar = auto.tqdm(enumerate(self.dlvalid), total=self.len_dlvalid)
        with torch.no_grad():
            for i, x in pbar:

                _loss, _acc = self.get_metric(x)
                pbar.desc = 'valid, epoch, %2d, loss, %2.4f, acc, %2.4f' % (epoch, _loss.item(), _acc.item())
                lr = self.eta
                self.logger_log(_loss.item(), _acc.item(), code, 'no', epoch, lr, 0)