import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import tqdm.auto as auto

from dataclasses import dataclass
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F



def forward_wrap_with_option_len(self, input_ids=None, labels=None, option_len=None, num_options=None, return_dict=None, **kwargs):
    """
    This is to replace the original forward function of Transformer models to enable:
    (1) Partial target sequence: loss will only be calculated on part of the sequence
    (2) Classification-style training: a classification loss (CE) will be calculated over several options
    Input:
    - input_ids, labels: same as the original forward function
    - option_len: a list of int indicating the option lengths, and loss will be calculated only on the
    last option_len tokens. When None, no positions are masked besides padding.
    - num_options: a list of int indicating the number of options for each example (this will be #label
    words for classification tasks and #choices for multiple choice tasks), and a classification loss
    will be calculated.
    Output:
    - if labels is None: whatever ``self.original_forward`` returned, untouched
    - if return_dict is falsy: a tuple ``(loss, accs)`` where ``accs`` is a list of accuracy tensors
    - otherwise: a ``CausalLMOutputWithPast`` carrying the loss and the raw logits (accuracies are dropped)

    In the language-modelling branch (num_options is None) the accuracy is the fraction of scored
    (non -100) positions whose argmax prediction equals the target token.
    """
    outputs = self.original_forward(input_ids=input_ids, **kwargs)
    if labels is None:
        return outputs
    logits = outputs.logits

    # get configs from the model (DataParallel keeps the real model under .module)
    config = self.module.config if isinstance(self, nn.DataParallel) else self.config
    pad_token_id = config.pad_token_id
    vocab_size = config.vocab_size

    loss = None
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    # Here we use input_ids (which should always = labels) bc sometimes labels are correct candidate IDs
    shift_labels = torch.clone(input_ids)[..., 1:].contiguous()
    if pad_token_id is not None:
        shift_labels[shift_labels == pad_token_id] = -100

    # Apply option len (do not calculate loss on the non-option part)
    if option_len is not None:
        for _i, _len in enumerate(option_len):
            shift_labels[_i, :-_len] = -100

    accs = []

    # Calculate the loss
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    if num_options is not None:
        # Train as a classification tasks
        log_probs = F.log_softmax(shift_logits, dim=-1)
        mask = shift_labels != -100 # Option part
        shift_labels[~mask] = 0 # So that it doesn't mess up with indexing

        selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len)
        selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options)

        if any(x != num_options[0] for x in num_options):
            # Multi choice tasks with different number of options
            loss = 0
            start_id = 0
            count = 0
            while start_id < len(num_options):
                end_id = start_id + num_options[start_id]
                _logits = selected_log_probs[start_id:end_id].unsqueeze(0) # (1, num_options)
                _labels = labels[start_id:end_id][0].unsqueeze(0) # (1)
                loss = loss_fct(_logits, _labels) + loss
                accs.append((_logits.argmax() == _labels).sum() / len(_labels))
                count += 1
                start_id = end_id
            loss = loss / count
        else:
            n_opt = num_options[0]
            selected_log_probs = selected_log_probs.view(-1, n_opt) # (bsz, num_options)
            labels = labels.view(-1, n_opt)[:, 0] # Labels repeat so we only take the first one
            loss = loss_fct(selected_log_probs, labels)
            accs.append((selected_log_probs.argmax(-1) == labels).sum() / len(labels))
    else:
        flat_logits = shift_logits.view(-1, vocab_size)
        flat_labels = shift_labels.view(-1)
        loss = loss_fct(flat_logits, flat_labels)
        # Compare per-token predictions with targets and only count scored positions,
        # i.e. the same positions the loss above is computed on.
        scored = flat_labels != -100
        hits = (flat_logits.argmax(-1) == flat_labels) & scored
        # clamp avoids 0/0 when every position is ignored
        accs = [hits.sum() / scored.sum().clamp(min=1)]

    if not return_dict:
        return (loss, accs)

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


class fedtrainer:
    def __init__(
        self, 
        mu=1e-4, 
        eta=1e-3,
        gamma=0.9995,
        momentum=0.9,
        weight_decay=0.01,
        slow_weight=0.01,
        k=1,
        K=5,
        model=None, 
        dltrain=None,
        dlvalid=None,
        loss_func=nn.CrossEntropyLoss(),
        classifier_only=False,
        optimizer_name=None,
        onebit=False,
        cuda_devices=None,
        nlp=False,
        byzantine=0,
    ):
        """
        Set up the model, the data loaders and (optionally) a torch optimizer.

        - mu: scale of the random perturbation applied in seed_perturb
        - eta: learning rate; passed to the optimizer and used to scale the seeded updates
        - gamma, slow_weight: stored on the instance only
        - momentum, weight_decay: optimizer hyper-parameters (momentum is only used by SGD)
        - k, K: inner / outer loop sizes of the zo_epoch_train_* methods; the outer loop
          indexes dltrain with range(K)
        - model: module to train; moved to the selected device and wrapped in
          nn.DataParallel when more than one CUDA device is given
        - dltrain: indexable collection of training loaders (len(dltrain[0]) is recorded)
        - dlvalid: validation loader
        - loss_func: loss used by get_metric_
        - classifier_only: optimize only parameters whose name contains 'classifier';
          otherwise every parameter whose name does not contain 'embed'
        - optimizer_name: None, 'SGD', 'Adam' or 'AdamW'
        - onebit: use seed_grad_onebit instead of seed_grad in zo_epoch_train_binary
        - cuda_devices: list of CUDA device indices; None means [0], an empty list means CPU
        - nlp: replace model.forward by forward_wrap_with_option_len
        - byzantine: loader indices below this value report corrupted directions during training

        Raises NotImplementedError for an unknown optimizer_name.
        """
        super().__init__()

        # None replaces the former mutable list default; behaviour is the same as [0].
        if cuda_devices is None:
            cuda_devices = [0]
        
        self.mu = mu
        self.eta = eta
        self.gamma = gamma
        self.momentum = momentum
        self.onebit = onebit
        self.weight_decay = weight_decay
        self.slow_weight = slow_weight
        self.k = k
        self.K = K
        # Multi-GPU support: the first listed device is the primary one
        device_name = "cuda:" + str(cuda_devices[0]) if len(cuda_devices) >= 1 else "cpu"
        device = torch.device(device_name)
        self.device = device
        
        model = model.to(device)
        self.is_cuda_parallel = len(cuda_devices) > 1
        if self.is_cuda_parallel:
            model = nn.DataParallel(model, device_ids=cuda_devices)
        self.model = model
        
        self.dltrain = dltrain
        self.dlvalid = dlvalid
        self.loss_func = loss_func
        self.optimizer_name = optimizer_name
        self.len_dltrain = len(self.dltrain[0])
        self.len_dlvalid = len(self.dlvalid)
        self.byzantine = byzantine
        
        if nlp:
            # Keep the stock forward reachable and bind the wrapper as the new forward
            self.model.original_forward = self.model.forward
            self.model.forward = forward_wrap_with_option_len.__get__(self.model, type(self.model))

        # Select trainable parameters by name in a single pass
        if classifier_only:
            selected = [(name, p) for name, p in self.model.named_parameters() if 'classifier' in name]
        else:
            selected = [(name, p) for name, p in self.model.named_parameters() if 'embed' not in name]
        self.optim_keys = [name for name, _ in selected]
        to_be_optim = [p for _, p in selected]
            
        print('The following parameters will be optimized:')
        for name in self.optim_keys:
            print(name)
            
        if optimizer_name is None:
            print('Using no optimizer')
        elif optimizer_name == 'SGD':
            print('Using SGD optimizer')
            self.optimizer = optim.SGD(
                to_be_optim,
                lr=self.eta,
                weight_decay=self.weight_decay,
                momentum=self.momentum,
            )
        elif optimizer_name == 'Adam':
            print('Using Adam optimizer')
            self.optimizer = optim.Adam(
                to_be_optim,
                lr=self.eta,
                weight_decay=self.weight_decay,
            )
        elif optimizer_name == 'AdamW':
            print('Using AdamW optimizer')
            self.optimizer = optim.AdamW(
                to_be_optim,
                lr=self.eta,
                weight_decay=self.weight_decay,
            )
        else:
            raise NotImplementedError('Unsupported optimizer_name: %r' % (optimizer_name,))
        
            
    def seed_perturb(self, seed, scale, mask=None):
        """
        Add seeded Gaussian noise, in place, to every parameter listed in self.optim_keys.

        The global torch RNG is reset to ``seed`` first, so calling this again with the
        same seed draws the same noise. Each selected parameter receives
        ``noise * self.mu * scale``, additionally multiplied by ``mask[name]`` when a
        mask dict is given.
        """
        torch.manual_seed(seed)

        for name, param in self.model.named_parameters():
            if name not in self.optim_keys:
                continue
            noise = torch.randn_like(param).to(param.device)
            delta = noise * self.mu * scale
            if mask is not None:
                delta = delta * mask[name]
            param.data += delta
                    
                    
    def seed_grad_onebit(self, seedlist, directionlist, mask=None):
        """
        Build a sign-voted gradient from seeded noise and take one optimizer step.

        ``seedlist`` / ``directionlist`` are expected in the layout produced by
        zo_epoch_train_binary: K consecutive groups of k entries, where entry
        ``g * k + i`` belongs to perturbation ``i``. For each of the ``self.k``
        perturbations the directions ``directionlist[i::self.k]`` are summed and
        reduced to +1 (sum > 0) or -1 (otherwise); the noise regenerated from
        ``seedlist[i]`` is then accumulated into ``.grad`` with that sign, scaled by
        ``self.eta / self.k`` and optionally by ``mask[name]``.

        The loop runs over the k perturbations. It previously ran over range(self.K),
        which re-used seeds and voted over truncated slices whenever K != k.

        Requires an optimizer (asserts self.optimizer_name is not None).
        """
        assert self.optimizer_name is not None, 'Must use an optimizer, non-optimizer version not implemented'
        self.optimizer.zero_grad()
        for i in range(self.k):
            torch.manual_seed(seedlist[i])

            # The vote does not depend on the parameter, so compute it once per seed.
            vote = sum(directionlist[i :: self.k])
            thisdirection = 1 if vote > 0 else -1

            for name, param in self.model.named_parameters():
                if name not in self.optim_keys:
                    continue
                update = torch.randn_like(param) * thisdirection
                if mask is not None:
                    update = update * mask[name]
                update = update / self.k * self.eta
                if param.grad is None:
                    param.grad = update
                else:
                    param.grad += update
        self.optimizer.step()
                    
                
                
    def seed_grad(self, seedlist, directionlist, mask=None):
        """
        Apply the update described by (seed, direction) pairs.

        For every index ``i`` the noise tensor of each parameter in self.optim_keys is
        regenerated from ``seedlist[i]`` and weighted by ``directionlist[i]``
        (and by ``mask[name]`` when a mask dict is given), then averaged over
        ``len(seedlist)`` and scaled by ``self.eta``.

        - With an optimizer: the result is accumulated into ``.grad`` and
          ``self.optimizer.step()`` is called once.
        - Without an optimizer: the result is subtracted from the parameters directly.
          Weight decay (``self.weight_decay * self.eta``) is applied once per call to
          parameters whose name contains none of 'bias', 'layer_norm', 'layernorm'.
          It used to sit inside the per-seed loop, which decayed the weights
          len(seedlist) times per update.
        """
        n_seeds = len(seedlist)
        # Materialize once so every seed walks the parameters in the same order
        targets = [(name, p) for name, p in self.model.named_parameters() if name in self.optim_keys]
        use_optimizer = self.optimizer_name is not None

        if use_optimizer:
            # NOTE(review): the gradient is scaled by self.eta here and the optimizer built
            # in __init__ also uses lr=self.eta -- confirm the double scaling is intended.
            self.optimizer.zero_grad()
        else:
            for name, param in targets:
                if 'bias' not in name and 'layer_norm' not in name and 'layernorm' not in name:
                    param.data -= param.data * self.weight_decay * self.eta

        for i in range(n_seeds):
            torch.manual_seed(seedlist[i])

            for name, param in targets:
                update = torch.randn_like(param) * directionlist[i]
                if mask is not None:
                    update = update * mask[name]
                update = update / n_seeds * self.eta

                if not use_optimizer:
                    param.data -= update
                elif param.grad is None:
                    param.grad = update
                else:
                    param.grad += update

        if use_optimizer:
            self.optimizer.step()
                        
    
    def logger_init(self):
        """Reset the logging state: empty per-record lists, seed at -1, counters at 0."""
        history_fields = (
            'losslist',
            'acclist',
            'modelist',
            'trainlist',
            'epochlist',
            'lrlist',
            'elapsed_steplist',
            'elapsed_timelist',
            'pgradlist',
        )
        # Each field gets its own fresh list
        for field in history_fields:
            setattr(self, field, [])

        self.elapsed_time = 0
        self.elapsed_step = 0
        self.seed = -1
        
    
    def logger_log(self, loss, acc, mode, train, epoch, lr, pgrad):
        """
        Append one record to the logging lists.

        The arguments go to their matching lists; the current values of
        self.elapsed_step and self.elapsed_time are recorded alongside them.
        """
        entries = (
            (self.losslist, loss),
            (self.acclist, acc),
            (self.modelist, mode),
            (self.trainlist, train),
            (self.epochlist, epoch),
            (self.elapsed_steplist, self.elapsed_step),
            (self.elapsed_timelist, self.elapsed_time),
            (self.pgradlist, pgrad),
            (self.lrlist, lr),
        )
        for history, value in entries:
            history.append(value)
        
        
    def logger_summary(self):
        """Return the logged records as a pandas DataFrame, one column per logging list."""
        # (column name, attribute holding the values), in output column order
        layout = (
            ('loss', 'losslist'),
            ('acc', 'acclist'),
            ('mode', 'modelist'),
            ('train', 'trainlist'),
            ('epoch', 'epochlist'),
            ('lr', 'lrlist'),
            ('time', 'elapsed_timelist'),
            ('step', 'elapsed_steplist'),
            ('pgrad', 'pgradlist'),
        )
        table = {column: getattr(self, attr) for column, attr in layout}
        return pd.DataFrame.from_dict(table)
    
    
    def get_metric_(self, x):
        """
        Compute (loss, accuracy) for a batch dict with 'pixel_values' and 'labels'.

        The model is called on the pixel values only; the loss is self.loss_func applied
        to the returned ``.logits`` and the labels, and the accuracy is the fraction of
        argmax predictions that equal the labels.
        """
        outputs = self.model(x['pixel_values'].to(self.device))
        targets = x['labels'].to(self.device)
        predictions = torch.argmax(outputs.logits, -1)
        n_correct = (predictions == targets).sum()
        acc = n_correct / len(targets)
        loss = self.loss_func(outputs.logits, targets)
        return loss, acc
    
    
    def get_metric(self, x):
        """
        Run the model on a batch dict and return ``(loss, acc)`` as detached tensors.

        Every value of ``x`` that has a ``.to`` method is moved to self.device; other
        values (e.g. plain Python lists) are passed to the model unchanged. The batch is
        forwarded as keyword arguments.

        - If the model returns a tuple, it is read as ``(loss, list_of_accuracies)`` and
          the accuracies are averaged.
        - Otherwise the output must expose ``.loss`` and ``.logits``; accuracy is the
          fraction of elements of ``x['labels']`` matched by ``logits.argmax(-1)``.

        When the model is wrapped for multi-GPU use, loss and accuracy are averaged
        with ``.mean()``.
        """
        batch = {
            key: value.to(self.device) if hasattr(value, 'to') else value
            for key, value in x.items()
        }
        fx = self.model(**batch)
        
        if isinstance(fx, tuple):
            loss = fx[0].detach()
            acc = sum(fx[1]).detach() / len(fx[1])
        else:
            loss = fx.loss.detach()
            labels = batch['labels']
            acc = (fx.logits.argmax(-1) == labels).sum() / labels.numel()
            
        if self.is_cuda_parallel:
            loss = loss.mean()
            acc = acc.mean()
        return loss, acc
    
    
    def get_direction(self, _loss, loss_):
        """Return ``(loss_ - _loss) / self.mu / 2`` as a Python float (both losses are read via ``.item()``)."""
        loss_gap = loss_.item() - _loss.item()
        return loss_gap / self.mu / 2
    
    
    def get_mask(self, quant=None):
        """
        Build ``self.mask``: for every named parameter, a boolean tensor that is True
        where the absolute value exceeds that parameter's ``quant``-quantile of
        absolute values.

        - quant: quantile in [0, 1]. When omitted, ``self.quant`` is used.

        Raises AttributeError when neither the argument nor ``self.quant`` is available
        (nothing in this class's constructor sets ``self.quant``).
        """
        if quant is None:
            quant = getattr(self, 'quant', None)
        if quant is None:
            raise AttributeError(
                "get_mask needs a quantile: pass quant=... or set self.quant first"
            )

        self.mask = {}
        for name, param in self.model.named_parameters():
            # detach: the mask is a constant, no autograd graph is needed for it
            magnitude = param.detach().abs()
            self.mask[name] = magnitude > torch.quantile(magnitude, quant)
            
            
    @torch.no_grad()
    def zo_epoch_train_binary(self, epoch):
        """
        Run one training epoch using only the sign of each two-point loss difference.

        Each round draws, for every loader index below self.K, self.k batches; for each
        batch the parameters are perturbed by -mu and +mu along a seeded direction,
        the two losses are compared and the sign (+1 / -1) of the estimate is kept.
        The seed counter is rewound after every loader, so all loaders reuse the same
        k seeds within a round. Loader indices below self.byzantine report the negated
        sign. The collected seeds and signs are applied through seed_grad_onebit when
        self.onebit is set, otherwise through seed_grad.

        Requires an optimizer (asserts self.optimizer_name is not None).
        """
        assert self.optimizer_name is not None, 'Must use optimizer to do binary training'

        n_rounds = max(1, self.len_dltrain // self.k // self.K)
        pbar = auto.tqdm(enumerate(range(n_rounds)), total=n_rounds)
        streams = [enumerate(loader) for loader in self.dltrain]
        for _round in pbar:

            self.elapsed_time += 1
            seedlist, directionlist = [], []

            for worker in range(self.K):
                for _rep in range(self.k):
                    _batch_idx, x = next(streams[worker])
                    self.seed += 1

                    # evaluate at theta - mu*z, then theta + mu*z, then restore theta
                    self.seed_perturb(self.seed, -1)
                    _loss, _acc = self.get_metric(x)
                    self.seed_perturb(self.seed, 2)
                    loss_, acc_ = self.get_metric(x)
                    self.seed_perturb(self.seed, -1)

                    if self.get_direction(_loss, loss_) > 0:
                        direction = 1
                    else:
                        direction = -1

                    pbar.desc = 'binary, train, epoch, %2d, loss, %2.4f, acc, %2.4f' % (epoch, loss_.item(), acc_.item())
                    self.elapsed_step += 1
                    self.logger_log(_loss.item(), _acc.item(), 'zo', 'yes', epoch, self.eta, direction)

                    seedlist.append(self.seed)
                    reported = - direction if worker < self.byzantine else direction
                    directionlist.append(reported)

                # rewind so the next loader reuses the same k seeds
                self.seed -= self.k

            # advance past the k seeds consumed in this round
            self.seed += self.k
            apply_update = self.seed_grad_onebit if self.onebit else self.seed_grad
            apply_update(seedlist, directionlist)
            
    
    @torch.no_grad()
    def zo_epoch_train_baseline(self, epoch):
        """
        Run one training epoch with real-valued two-point loss-difference estimates.

        Each round draws, for every loader index below self.K, self.k batches, each with
        its own fresh seed. The parameters are perturbed by -mu and +mu along the seeded
        direction and get_direction turns the two losses into a scalar. Loader indices
        below self.byzantine report scaled Gaussian noise instead of their estimate.
        All seeds and scalars of the round are applied with a single seed_grad call.
        """
        n_rounds = max(1, self.len_dltrain // self.k // self.K)
        pbar = auto.tqdm(enumerate(range(n_rounds)), total=n_rounds)
        streams = [enumerate(loader) for loader in self.dltrain]
        for _round in pbar:

            self.elapsed_time += 1
            seedlist, directionlist = [], []

            for worker in range(self.K):
                for _rep in range(self.k):
                    self.seed += 1
                    _batch_idx, x = next(streams[worker])

                    # evaluate at theta - mu*z, then theta + mu*z, then restore theta
                    self.seed_perturb(self.seed, -1)
                    _loss, _acc = self.get_metric(x)
                    self.seed_perturb(self.seed, 2)
                    loss_, acc_ = self.get_metric(x)
                    self.seed_perturb(self.seed, -1)

                    direction = self.get_direction(_loss, loss_)
                    pbar.desc = 'baseline, train, epoch, %2d, loss, %2.4f, acc, %2.4f' % (epoch, loss_.item(), acc_.item())
                    self.elapsed_step += 1
                    self.logger_log(_loss.item(), _acc.item(), 'zo', 'yes', epoch, self.eta, direction)
                    seedlist.append(self.seed)

                    if worker < self.byzantine:
                        reported = np.random.randn() * 31.622776601683793319988935444327
                    else:
                        reported = direction
                    directionlist.append(reported)

            self.seed_grad(seedlist, directionlist)
            
            
    @torch.no_grad()
    def zo_epoch_train_baseline_shorter(self, epoch):
        """
        Variant of the baseline epoch that draws one batch per loader per round.

        For every loader index below self.K a single batch is fetched and reused for
        self.k perturbations, each with a fresh seed. The resulting seeds and
        get_direction scalars of the round are applied with one seed_grad call.
        No byzantine corruption is applied in this variant.
        """
        n_rounds = max(1, self.len_dltrain // self.K)
        pbar = auto.tqdm(enumerate(range(n_rounds)), total=n_rounds)
        streams = [enumerate(loader) for loader in self.dltrain]
        for _round in pbar:

            self.elapsed_time += 1
            seedlist, directionlist = [], []

            for worker in range(self.K):
                _batch_idx, x = next(streams[worker])
                for _rep in range(self.k):
                    self.seed += 1

                    # evaluate at theta - mu*z, then theta + mu*z, then restore theta
                    self.seed_perturb(self.seed, -1)
                    _loss, _acc = self.get_metric(x)
                    self.seed_perturb(self.seed, 2)
                    loss_, acc_ = self.get_metric(x)
                    self.seed_perturb(self.seed, -1)

                    direction = self.get_direction(_loss, loss_)
                    pbar.desc = 'baseline, train, epoch, %2d, loss, %2.4f, acc, %2.4f' % (epoch, loss_.item(), acc_.item())
                    self.elapsed_step += 1
                    self.logger_log(_loss.item(), _acc.item(), 'zo', 'yes', epoch, self.eta, direction)
                    seedlist.append(self.seed)
                    directionlist.append(direction)

            self.seed_grad(seedlist, directionlist)
            
            
    def epoch_valid(self, epoch, code='zo'):
        """
        Evaluate every batch of self.dlvalid and log one record per batch.

        - epoch: epoch number written to the log and the progress-bar text
        - code: value stored in the log's 'mode' column

        Records are logged with train='no' and pgrad=0. The pass runs under
        torch.no_grad(), like the training epochs, since nothing here needs gradients.
        """
        pbar = auto.tqdm(enumerate(self.dlvalid), total=self.len_dlvalid)
        with torch.no_grad():
            for _, x in pbar:
                _loss, _acc = self.get_metric(x)
                pbar.desc = 'valid, epoch, %2d, loss, %2.4f, acc, %2.4f' % (epoch, _loss.item(), _acc.item())
                self.logger_log(_loss.item(), _acc.item(), code, 'no', epoch, self.eta, 0)