import torch
import math
import tqdm
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
import numpy as np
import os
from torch.nn.utils import clip_grad_norm_
from utils import move_to
import math
from torch.nn.functional import one_hot, log_softmax

class BMIX(SingleModelAlgorithm):
    """
    Single-model algorithm that trains on Brownian-bridge mixtures of a batch.

    During training, a uniform draw is compared against ``umix_sigma``: when the
    draw is at most that value the batch is replaced by bridge samples produced
    by ``cutmix_bridge`` / ``bridgemix``; otherwise the raw batch is used. A
    per-sample weight (in most paths built from the gradient of the loss with
    respect to the input) is scaled by ``bmix_xi`` and subtracted from the loss
    in ``_update``.
    """

    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps):
        """
        Build the model and cache the hyperparameters read from ``config``.
        Args:
            - config: configuration object; the attributes read are listed below
            - d_out (int): number of classes, forwarded to ``initialize_model``
            - grouper, loss, metric, n_train_steps: forwarded to the superclass
        """
        # initialize model
        if config.data_parallel:
            featurizer, classifier = initialize_model(config, d_out=d_out, is_featurizer=True)
            featurizer = featurizer.to(config.device)
            classifier = classifier.to(config.device)
            model = torch.nn.Sequential(featurizer, classifier).to(config.device)
            assert config.device == 'cuda'
            model = torch.nn.DataParallel(model)
        else:
            model = initialize_model(config, d_out).to(config.device)
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        if config.data_parallel:
            # Only set in the data-parallel configuration; the "manifoldmix"
            # paths read these attributes and therefore require it.
            self.featurizer = featurizer
            self.classifier = classifier
        self.seed = config.seed
        # Hard-coded; only "grad" is implemented in process_batch.
        self.weight_type = "grad"#config.weight_type

        # bridge hyperparameters
        self.n_t = config.bmix_n_t                # number of time points per bridge
        self.terminal_T = config.bmix_terminal_T  # last time point of the bridge grid
        self.var = config.bmix_var                # forwarded as `var` to the bridge samplers
        self.sigma = config.umix_sigma            # threshold for taking the mixing branch
        self.num_of_classes = d_out
        self.config = config
        self.batch_size = config.batch_size
        self.device = config.device
        self.mixup_type = config.umix_mixup_type
        self.xi = config.bmix_xi                  # scale of the importance-weight term
        self.lr = config.lr
        self.weight_decay = config.weight_decay
        self.weight_decay_gammas = config.weight_decay_gammas
        self.weight_decay_schedule = config.weight_decay_schedule
        self.simple_grad = config.simple_grad
        self.sub_sample_group = config.sub_sample_group
        self.noise = config.noise

    def update(self, batch, epoch):
        """
        Process the batch, update the log, and update the model
        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
            - epoch (int): current epoch, used to pick the weight decay
        Output:
            - results (dictionary): information about the batch, such as:
                - g (Tensor)
                - y_true (Tensor)
                - metadata (Tensor)
                - outputs (Tensor)
                - y_pred (Tensor)
                - objective (float)
        """
        assert self.is_training
        # process batch
        self.adjust_weight_decay(epoch)
        results = self.process_batch(batch)
        self._update(results)
        # log results
        self.update_log(results)
        return self.sanitize_dict(results)

    def adjust_weight_decay(self, epoch):
        """
        Set the weight decay of every optimizer parameter group for ``epoch``.

        Starts from ``self.weight_decay`` and, walking the ``(gamma, step)``
        pairs after the first entry of each list, adds ``gamma`` for every
        ``step`` already reached; stops at the first step not yet reached.
        """
        assert len(self.weight_decay_gammas) == len(self.weight_decay_schedule), "length of gammas and schedule should be equal"
        wd = self.weight_decay
        # The former "warm up" branch only assigned an unused local and divided
        # by weight_decay_schedule[0] (a ZeroDivisionError when that entry is 0),
        # so it never influenced the applied weight decay and has been removed.
        for (gamma, step) in zip(self.weight_decay_gammas[1:], self.weight_decay_schedule[1:]):
            if epoch < step:
                break
            wd = wd + gamma
        for param_group in self.optimizer.param_groups:
            param_group['weight_decay'] = wd

    def importance_weights(self, x, sample_grads, t, batch_size, n_t, simple=True):
        """
        Calculate importance weights from the input gradients.
        Args:
            - x (Tensor): bridge samples; reshaped to (batch_size, n_t, ...)
            - sample_grads (Tensor): gradient of the loss w.r.t. ``x``
            - t (Tensor): time grid, only used when ``simple`` is False
            - batch_size (int): number of original samples
            - n_t (int): number of time points per sample
            - simple (bool): use the closed form ``<grad, x> - 0.5 * |grad|^2``
        Output:
            - weight (Tensor)
        """
        if simple:
            # Uses the batch_size argument (previously self.batch_size was read
            # and the argument ignored; every caller passes the same value).
            weight = (sample_grads * x).reshape(batch_size, n_t, -1).sum(-1) - \
                0.5*(sample_grads**2).reshape(batch_size, n_t, -1).sum(-1)
        else:
            if n_t > 1:
                x = x.reshape(batch_size, n_t, *x.shape[1:])
                sample_grads = sample_grads.reshape(batch_size, n_t, *sample_grads.shape[1:])
                dX = torch.diff(x, dim=1)
                dt = torch.diff(t, dim=1)
                # NOTE(review): the second term multiplies the gradient (not its
                # square) by dt, unlike the `simple` branch -- confirm intended.
                weight = (dX.flatten(2)*sample_grads.flatten(2)[:,0:-1]).sum(-1) - \
                    1/2 * torch.sum(sample_grads.flatten(2)[:,0:-1]*dt, dim=-1)
            else:
                weight = sample_grads.flatten(2).sum(-1) - \
                    1/2 * torch.sum(sample_grads.flatten(2), dim=-1)
        return weight

    def evaluate_uncertainty(self, batch):
        """
        Score a batch by the norm of the loss gradient w.r.t. the model input.
        Args:
            - batch (tuple): (x, y_true, metadata, idx)
        Output:
            - results (dictionary): same keys as the mixing branch of
              process_batch; 'weight' holds one gradient norm per sample and
              'y_a', 'y_b', 'mix', 'idx' are None.
        Raises:
            - NotImplementedError for mixup_type "vanillamix"
            - ValueError for an unrecognised mixup_type
        """
        x, y_true, metadata, idx = batch
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)

        if self.mixup_type == "vanillamix":
            # This branch read an importance weight before it was assigned and
            # never produced `sample_grads`, so it could not complete.
            raise NotImplementedError("mixup_type 'vanillamix' is not supported by BMIX.evaluate_uncertainty")
        elif self.mixup_type == "cutmix":
            x = x.float()
            self.batch_size = x.shape[0]
            # NOTE(review): requires cutmix_bridge to accept the `shuffle` keyword.
            x_bridge, y_bridge, _, _, _, _, t, n_t = cutmix_bridge(x, y_true, self.device, self.num_of_classes,
                                                                   self.terminal_T, self.n_t, self.var, shuffle=False)
            x_bridge.requires_grad = True
            outputs = self.model(x_bridge)
            loss = self.loss.loss_fn(outputs, y_bridge).mean()
            sample_grads = torch.autograd.grad(loss, x_bridge, retain_graph=True)[0].detach()
        elif self.mixup_type == "manifoldmix":
            x_f = self.featurizer(x).detach()
            self.batch_size = x_f.shape[0]
            x_bridge, y_bridge, _, _, _, _, t, n_t = bridgemix(x_f, y_true, self.device, self.num_of_classes,
                                                               self.terminal_T, self.n_t, self.var, shuffle=False)
            x_bridge.requires_grad = True
            outputs = self.classifier(x_bridge)
            loss = self.loss.loss_fn(log_softmax(outputs, dim=1), y_bridge).mean()
            sample_grads = torch.autograd.grad(loss, x_bridge, retain_graph=True)[0].detach()
        elif self.mixup_type == "bridgemix":
            x = x.float()
            x.requires_grad = True
            self.batch_size = x.shape[0]
            y_b_onehot = one_hot(y_true, self.num_of_classes)
            outputs = self.model(x)
            # The target is the uniform distribution over classes.
            loss = self.loss.loss_fn(log_softmax(outputs, dim=1),
                                     torch.ones_like(y_b_onehot)/self.num_of_classes).mean()
            sample_grads = torch.autograd.grad(loss, x, retain_graph=True)[0].detach()
        else:
            raise ValueError(f"unknown mixup_type: {self.mixup_type!r}")

        # One norm per original sample, taken over all of its gradient entries.
        weight = torch.norm(sample_grads.reshape(self.batch_size, -1), dim=-1)
        results = {
            'g': g,
            'y_a': None,
            'y_b': None,
            'y_true': y_true,
            'y_pred_bridge':outputs,
            'y_pred': outputs,
            'metadata': metadata,
            'mix':None,
            'weight':weight,
            'idx': None
        }
        return results

    def _endpoint_shuffle(self, g):
        """Return group-balanced bridge endpoints when enabled, else None."""
        if self.sub_sample_group:
            return subsample_group_endpoint(g)
        return None

    def process_batch(self, batch, test=False):
        """
        A helper function for update() and evaluate() that processes the batch
        Args:
            - batch (tuple of Tensors): (x, y_true, metadata, idx)
            - test (bool): when not training and ``self.noise > 0``, add
              Gaussian noise scaled by ``self.noise`` to the input
        Output:
            - results (dictionary): information about the batch
                - g, y_true, metadata, idx, y_pred (always)
                - weight (training only)
                - y_a, y_b, y_pred_bridge, mix (training, mixing branch only)
        Raises:
            - NotImplementedError for mixup_type "vanillamix" in the mixing
              branch, or a weight_type other than "grad"
            - ValueError for an unrecognised mixup_type in the mixing branch
        """
        x, y_true, metadata, idx = batch
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)
        # Add noise for robustness testing
        if not self.is_training and self.noise > 0 and test:
            x += torch.randn_like(x) * self.noise
        # Drawn on every call, including evaluation, as before.
        p = np.random.rand(1)

        if not self.training:
            outputs = self.model(x)
            return {
                'g': g,
                'y_true': y_true,
                'y_pred': outputs,
                'metadata': metadata,
                'idx': idx
            }

        if p <= self.sigma:
            if self.mixup_type == "vanillamix":
                # This branch read `weight` before assignment and the result
                # dictionary needs bridge outputs it never produced.
                raise NotImplementedError("mixup_type 'vanillamix' is not supported by BMIX.process_batch")
            elif self.mixup_type == "cutmix":
                x = x.float()
                self.batch_size = x.shape[0]
                # NOTE(review): computed but not forwarded; kept so the random
                # draws made here are unchanged.
                input_shuffle = self._endpoint_shuffle(g)
                x_bridge, y_bridge, y_a, y_b, mix, idx, t, n_t = cutmix_bridge(x, y_true, self.device, self.num_of_classes,
                                                                               self.terminal_T, self.n_t, self.var)
                x_bridge.requires_grad = True
                outputs = self.model(x_bridge)
                loss = self.loss.loss_fn(outputs, y_bridge).mean()
                sample_grads = torch.autograd.grad(loss, x_bridge, retain_graph=True)[0]
            elif self.mixup_type == "manifoldmix":
                # Bridges are built in feature space; the featurizer is not
                # updated through this path because its output is detached.
                x_f = self.featurizer(x).detach()
                self.batch_size = x_f.shape[0]
                input_shuffle = self._endpoint_shuffle(g)
                x_bridge, y_bridge, y_a, y_b, mix, idx, t, n_t = bridgemix(x_f, y_true, self.device, self.num_of_classes,
                                                                           self.terminal_T, self.n_t, self.var,
                                                                           input_shuffle=input_shuffle)
                x_bridge.requires_grad = True
                outputs = self.classifier(x_bridge)
                loss = self.loss.loss_fn(outputs, y_bridge).mean()
                sample_grads = torch.autograd.grad(loss, x_bridge, retain_graph=True)[0]
            elif self.mixup_type == "bridgemix":
                x = x.float()
                x.requires_grad = True
                self.batch_size = x.shape[0]
                input_shuffle = self._endpoint_shuffle(g)
                x_bridge, y_bridge, y_a, y_b, mix, idx, t, n_t = bridgemix(x, y_true, self.device, self.num_of_classes,
                                                                           self.terminal_T, self.n_t, self.var,
                                                                           input_shuffle=input_shuffle)
                outputs = self.model(x_bridge)
                loss = self.loss.loss_fn(log_softmax(outputs, dim=1).reshape(self.batch_size, n_t, -1),
                                         y_bridge.reshape(self.batch_size, n_t, -1)).mean()
                # create_graph=True keeps the weight differentiable so that it
                # contributes to the parameter gradients in _update.
                sample_grads = torch.autograd.grad(loss, x_bridge, retain_graph=True, create_graph=True)[0]
            else:
                raise ValueError(f"unknown mixup_type: {self.mixup_type!r}")

            weight = self.importance_weights(x_bridge, sample_grads, t, self.batch_size, n_t, self.simple_grad)
            return {
                'g': g,
                'y_a': y_bridge,
                'y_b': y_b,
                'y_true': y_true,
                'y_pred_bridge':outputs,
                # prediction at the first time point of every bridge
                'y_pred': outputs.reshape(self.batch_size, n_t, -1)[:,0],
                'metadata': metadata,
                'mix':mix,
                'weight':weight,
                'idx': idx
            }

        # No mixing for this batch.
        x.requires_grad = True
        outputs = self.model(x)
        if self.mixup_type != "bridgemix":
            # NOTE(review): self.trajectory is not assigned anywhere in this
            # class; presumably provided by the caller -- confirm.
            weight = move_to(torch.from_numpy(self.trajectory.get_weight(idx)), self.device)
        else:
            # weight is built from the gradient of the loss w.r.t. x
            self.batch_size = x.shape[0]
            if self.weight_type == "grad":
                loss = self.loss.loss_fn(log_softmax(outputs, dim=1), one_hot(y_true, self.num_of_classes)).mean()
                sample_grads = torch.autograd.grad(loss, x, retain_graph=True, create_graph=True)[0]
                weight = (sample_grads * x).reshape(self.batch_size, -1).sum(-1) - \
                    0.5*(sample_grads**2).reshape(self.batch_size, -1).sum(-1)
            else:
                # The former "vjp" variant computed a vector-Jacobian product
                # but never assigned `weight`, so it could not complete.
                raise NotImplementedError(f"weight_type {self.weight_type!r} is not supported; use 'grad'")
        return {
            'g': g,
            'y_true': y_true,
            'y_pred': outputs,
            'metadata': metadata,
            'idx': idx,
            'weight': weight
        }

    def objective(self, results):
        """Return the mean element-wise loss over the first ``self.batch_size`` predictions."""
        # compute group losses
        element_wise_losses = self.loss.compute_element_wise(
            results['y_pred'][:self.batch_size],
            results['y_true'][:self.batch_size],
            return_dict=False)
        loss = element_wise_losses.mean()
        return loss

    def _update(self, results):
        """
        Compute the weighted objective, take an optimizer step, and step the
        schedulers.
        Args:
            - results (dictionary): output of process_batch in training mode;
              the presence of the 'mix' key selects the bridge objective.
        Side effects:
            - stores the scalar objective under results['objective']
        """
        if 'mix' in results:
            element_wise_losses = self.loss.compute_element_wise(log_softmax(results['y_pred_bridge'], dim=1),
                                                                 results['y_a'], return_dict=False)
            # sum over the time points of each bridge, then average over samples
            objective = (element_wise_losses.reshape(self.batch_size, self.n_t) - self.xi*results['weight']).sum(1).mean()
        else:
            # num_classes must be given explicitly: without it one_hot sizes the
            # target from the largest label in the batch, which can be narrower
            # than the model output.
            element_wise_losses = self.loss.compute_element_wise(log_softmax(results['y_pred'], dim=1),
                                                                 one_hot(results['y_true'], self.num_of_classes),
                                                                 return_dict=False)
            objective = (element_wise_losses - self.xi*results['weight']).mean()
        self.model.zero_grad()
        objective.backward()
        #if self.max_grad_norm:
        #    clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False)
        results['objective'] = objective.item()


class trajectory(object):
    """
    Turns stored per-sample trajectories into per-sample weights.

    ``trajectoris`` is indexed as [sample, step]. Loading from disk is disabled
    here, so the attribute starts as None and has to be assigned before
    ``get_weight`` is called.
    """
    def __init__(self, n_data, args):
        """
        Args:
            - n_data: stored as ``self.n_data``
            - args: object providing ``T_s`` (window start), ``T`` (window
              length) and ``eta`` (weight scale)
        """
        self.trajectoris = None #np.load(os.path.join(args.trajectory_path, "trajectory_"+str(args.seed)+".npy"), allow_pickle=True)
        self.n_data = n_data
        self.T_s = args.T_s
        self.T = args.T
        self.eta = args.eta

    def get_weight(self, idx):
        """
        Return ``(1 - mean) * eta + 1`` for the samples in ``idx``, where the
        mean is taken over steps ``T_s`` to ``T_s + T`` of each trajectory.
        Raises:
            - RuntimeError if no trajectories have been assigned
        """
        if self.trajectoris is None:
            # Fail with a clear message instead of "'NoneType' object is not subscriptable".
            raise RuntimeError("trajectory data is not loaded; assign `trajectoris` before calling get_weight")
        window = self.trajectoris[idx, self.T_s : self.T_s + self.T]
        weight = (1 - np.mean(window, axis=1)) * self.eta + 1
        return weight


def mix_up(x, l, beta_param, device, weight, shuffle=True):
    """
    Mix every sample of a batch with a partner sample from the same batch.

    Args:
        x: the input batch, indexed along dim 0 by sample
        l: the label batch, indexed along dim 0 by sample
        beta_param: both parameters of the Beta distribution the mixing
            coefficient is drawn from
        device: device the permutation index is moved to
        weight: per-sample weights, indexed along dim 0 by sample
        shuffle: when True (default) the partner of each sample is chosen by a
            random permutation; when False each sample is paired with itself,
            so the mixed input equals ``x`` up to rounding
    Returns:
        (xmix, y_a, y_b, lam, weight_b): the mixed inputs, the labels of both
        endpoints, the mixing coefficient, and the weights of the partners.
    """
    if shuffle:
        batch_size = x.shape[0]
        idx = move_to(torch.randperm(batch_size), device)
        x_b = x[idx]
        l_b = l[idx]
        weight_b = weight[idx]
    else:
        # Identity pairing: no permutation is drawn.
        x_b = x
        l_b = l
        weight_b = weight
    lam = np.random.beta(beta_param, beta_param)
    xmix = lam * x + (1 - lam) * x_b
    y_a = l
    y_b = l_b
    return xmix, y_a, y_b, lam, weight_b

def brownian_bridge_ab(t, a, b, var=1, pp=True, simplex=False):
    '''
    Sample a random path that starts at ``a`` and ends at ``b``.

    Args:
        t: time grid, indexed as [batch, time, ...]; needs at least two time
            points because the step is taken from the first two
        a: start points, broadcast against the time axis
        b: end points, same layout as ``a``
        var: multiplier applied to every random increment
        pp: when truthy, each increment is additionally scaled by
            ``-log(u + 1e-4)`` with ``u`` uniform on [0, 1)
        simplex: when True, take absolute values and normalise the last
            dimension to sum to one
    Returns:
        (path, t_normalised): the sampled path and the time grid rescaled so
        that it runs from 0 to 1 along the time axis.
    '''
    step = t[:, 1] - t[:, 0]
    origin = t[:, 0].unsqueeze(1)
    span = (t[:, -1] - t[:, 0]).unsqueeze(1)
    t = (t - origin) / span

    # The uniform draw stays ahead of the Gaussian draw so the random stream
    # is consumed in the same order as before.
    scale = -(torch.rand_like(t) + 1e-4).log() if pp else 1

    increments = torch.randn_like(t) * step.sqrt().unsqueeze(1) * var * scale
    walk = increments.cumsum(1)
    walk[:, 0] = 0
    walk = walk + a.unsqueeze(1)

    # Subtract the linearly growing share of the end-point error so the path lands on b.
    overshoot = (walk[:, -1] - b).unsqueeze(1)
    bridge = walk - t * overshoot
    if simplex:
        magnitude = bridge.abs()
        bridge = magnitude / magnitude.sum(-1, keepdims=True)
    return bridge, t

def subsample_group_endpoint(group_labels):
    """
    Draw a shuffled index that inverts the group frequencies of a batch.

    Groups are ranked by size; the smallest group is sampled as many times as
    the largest group occurs, the second smallest as often as the second
    largest, and so on. The result therefore has as many entries as
    ``group_labels``.

    Args:
        group_labels: 1-D tensor of group ids, one per sample
    Returns:
        LongTensor of positions into ``group_labels``, in random order
    """
    group_ids, sizes = torch.unique(group_labels, return_counts=True)
    rarest_first = group_ids[sizes.argsort(descending=False)]
    largest_first, _ = sizes.sort(descending=True)

    picked = []
    for group_id, n_draws in zip(rarest_first, largest_first):
        members = torch.where(group_labels == group_id)[0].detach().cpu()
        uniform = torch.ones_like(members).float()
        # sampling with replacement lets a small group fill a large quota
        draws = torch.multinomial(uniform, n_draws, replacement=True)
        picked.extend(members[draws].numpy().tolist())

    picked = torch.Tensor(picked).long()
    order = torch.randperm(picked.shape[0])
    return picked[order]

def bridgemix(x, l, device, num_of_classes, terminal_T=1, n_t=5, var=0.0001, shuffle=True, input_shuffle=None):
    """
    Sample bridges between each input and a partner input, and between their
    one-hot labels.

    Args:
        x: the input batch, indexed along dim 0 by sample
        l: the integer label batch, one label per sample
        device: device the time grids and the pairing index are moved to
        num_of_classes: width of the one-hot labels
        terminal_T: last value of the time grid
        n_t: number of time points per bridge
        var: bridge diffusion coefficient
        shuffle: pair each sample with a partner (True) or with itself (False)
        input_shuffle: optional pairing index used instead of a random
            permutation; only read when ``shuffle`` is True

    Returns:
        (xmix, ymix, y_a, y_b, lam, idx, t, n_t). With ``shuffle=False`` the
        labels ``ymix`` are uniform over the classes and ``y_a``, ``y_b``,
        ``lam`` and ``idx`` are None.
    """
    batch_size = x.shape[0]
    d = int(torch.prod(torch.tensor(x.shape[1:])))
    t = torch.linspace(0, terminal_T, n_t).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, d).float().to(device)
    t_label = torch.linspace(0, terminal_T, n_t).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, 1).float().to(device)

    # choose the far endpoint of every bridge
    if shuffle:
        if input_shuffle is None:
            idx = move_to(torch.randperm(batch_size), device)
        else:
            idx = move_to(input_shuffle, device)
        x_end = x[idx]
        l_end = l[idx]
    else:
        idx = None
        x_end = x
        l_end = l

    start_flat = x.flatten(1)
    end_flat = x_end.flatten(1)

    start_onehot = one_hot(l, num_of_classes)
    end_onehot = one_hot(l_end, num_of_classes)

    def per_time_point(labels):
        # one copy of each label per time point of its bridge
        return labels.unsqueeze(1).repeat(1, t_label.shape[1], 1).reshape(n_t, batch_size).T.flatten()

    l_a = per_time_point(l)
    l_b = per_time_point(l_end)

    # inputs first, labels second: the order of the random draws matters
    xmix = (brownian_bridge_ab(t, start_flat, end_flat, var)[0]).reshape(-1, *x.shape[1:])
    y_bridge = (brownian_bridge_ab(t_label, start_onehot, end_onehot, var, simplex=True)[0]).reshape(-1, *start_onehot.shape[1:])

    if not shuffle:
        uniform_labels = torch.ones_like(y_bridge)*(1/num_of_classes)
        return xmix, uniform_labels, None, None, None, idx, t, n_t

    lam = ((terminal_T - t_label[...,0])/terminal_T).reshape(-1)
    return xmix, y_bridge, l_a, l_b, lam, idx, t, n_t


def mixup_criterion(y_a, y_b, lam):
    """
    Build a loss that blends a criterion evaluated against two targets.

    The returned callable takes ``(criterion, pred)`` and returns
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    def blended(criterion, pred):
        toward_a = criterion(pred, y_a)
        toward_b = criterion(pred, y_b)
        return lam * toward_a + (1 - lam) * toward_b

    return blended

def cutmix(x, l, beta_param, device, weight):
    """
    Paste a random box from a partner sample into every sample of the batch.

    ``x`` is modified in place and indexed as [sample, channel, dim2, dim3];
    the box is cut along the last two dimensions.

    Args:
        x: the input batch
        l: the label batch
        beta_param: both parameters of the Beta distribution that sets the
            requested box size
        device: device the permutation index is moved to
        weight: per-sample weights
    Returns:
        (x, l, l2, lam, weight2): the patched inputs, the original labels, the
        partner labels, the share of each sample left untouched, and the
        partner weights.
    """
    n_samples = x.shape[0]
    perm = move_to(torch.randperm(n_samples), device)
    partner_labels = l[perm]
    partner_weight = weight[perm]

    requested = np.random.beta(beta_param, beta_param)
    lo2, lo3, hi2, hi3 = rand_bbox(x.size(), requested)
    x[:, :, lo2:hi2, lo3:hi3] = x[perm, :, lo2:hi2, lo3:hi3]

    # recompute the coefficient from the box actually cut (it may be clipped)
    box_area = (hi2 - lo2) * (hi3 - lo3)
    lam = 1 - (box_area / (x.size()[-1] * x.size()[-2]))
    return x, l, partner_labels, lam, partner_weight

def cutmix_bridge(x, l, device, num_of_classes, terminal_T=1, n_t=5, var=0.0001, shuffle=True, input_shuffle=None):
    """
    CutMix a batch with partner samples, then sample bridges between each
    patched input and its partner, and between their one-hot labels.

    Args:
        x: the input batch, indexed as [sample, channel, dim2, dim3]; it is
            not modified (the box is pasted into a copy)
        l: the integer label batch, one label per sample
        device: device the time grids and the pairing index are moved to
        num_of_classes: width of the one-hot labels
        terminal_T: last value of the time grid
        n_t: number of time points per bridge
        var: bridge diffusion coefficient
        shuffle: when True (default) pair every sample with a partner and
            paste a random box from it; when False pair every sample with
            itself and paste nothing, mirroring ``bridgemix(shuffle=False)``
        input_shuffle: optional pairing index used instead of a random
            permutation; only read when ``shuffle`` is True

    Returns:
        (xmix, ymix, y_a, y_b, lam, idx, t, n_t). With ``shuffle=False`` the
        labels ``ymix`` are uniform over the classes and ``y_a``, ``y_b``,
        ``lam`` and ``idx`` are None.
    """
    batch_size = x.shape[0]
    d = int(torch.prod(torch.tensor(x.shape[1:])))
    t = torch.linspace(0, terminal_T, n_t).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, d).float().to(device)
    t_label = torch.linspace(0, terminal_T, n_t).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, 1).float().to(device)

    if shuffle:
        if input_shuffle is None:
            idx = move_to(torch.randperm(batch_size), device)
        else:
            idx = move_to(input_shuffle, device)
        bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), 0.5)
        # Work on a copy so the caller's batch is not overwritten.
        x = x.clone()
        x[:, :, bbx1:bbx2, bby1:bby2] = x[idx, :, bbx1:bbx2, bby1:bby2]
        x_end = x[idx]
        l_end = l[idx]
    else:
        idx = None
        x_end = x
        l_end = l

    x_a = x.flatten(1)
    x_b = x_end.flatten(1)

    l_a_onehot = one_hot(l, num_of_classes)
    l_b_onehot = one_hot(l_end, num_of_classes)

    # inputs first, labels second: the order of the random draws matters
    xmix = (brownian_bridge_ab(t, x_a, x_b, var)[0]).reshape(-1, *x.shape[1:])
    y_bridge = (brownian_bridge_ab(t_label, l_a_onehot, l_b_onehot, var, simplex=True)[0]).reshape(-1, *l_a_onehot.shape[1:])

    if not shuffle:
        ymix = torch.ones_like(y_bridge)*(1/num_of_classes)
        return xmix, ymix, None, None, None, idx, t, n_t

    # one copy of each label per time point of its bridge
    y_a = l.unsqueeze(1).repeat(1,t_label.shape[1],1).reshape(n_t,batch_size).T.flatten()
    y_b = l_end.unsqueeze(1).repeat(1,t_label.shape[1],1).reshape(n_t,batch_size).T.flatten()
    lam = ((terminal_T - t_label[...,0])/terminal_T).reshape(-1)

    return xmix, y_bridge, y_a, y_b, lam, idx, t, n_t

def rand_bbox(size, lam):
    """
    Sample a random box for CutMix.

    The box is centred uniformly at random and clipped to the tensor bounds;
    before clipping, each side is ``sqrt(1 - lam)`` times the matching extent.

    Args:
        size: shape of the batch; ``size[2]`` and ``size[3]`` are the two
            extents the box is cut along
        lam: requested share of the sample to keep
    Returns:
        (bbx1, bby1, bbx2, bby2): start and end of the box along dimension 2
        (``bbx``) and dimension 3 (``bby``).
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int: the np.int alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2