from datetime import datetime
from .utils import Yaccu,calculate_time,mixup_data
from .gradient_correction import GradientCorrection
import torch
import copy
import numpy as np
import ipdb
from .hilbert_projection import HilbertProjection
import time
from timeit import default_timer as timer
import concurrent.futures
# Local wall-clock time at import, formatted to minute resolution.
time_stamp = datetime.now().strftime('%Y-%m-%d-%H:%M')


def Id(x):
    """Identity function: return ``x`` unchanged (default gradient correction)."""
    return x

class KernelModel:
    """Kernel model trained by corrected gradient steps projected onto a knot set.

    Predictions are ``kernel_fn(X, knots[0]) @ weights`` (see ``predict``).
    Each training batch computes the residual on the batch, evaluates it at the
    knots, subtracts the correction produced by ``gradient_correction_fn`` and
    then re-fits ``weights`` with a ``HilbertProjection`` (see ``fit_batch`` and
    ``update_weights``).

    ``knots`` is indexed by device position: ``knots[ind]`` is used together
    with ``device[ind]`` and ``knots[-1]`` is sent to ``projection_device``.
    Presumably every entry is the same knot matrix placed on a different
    device — TODO confirm.
    """

    def __init__(
            self, knots, kernel_fn, gradient_correction_fn=None, n_labels=1,
            device=None,
            projection_device=torch.device("cpu"), projection_cutoff=1 * 10 ** -5,
            wandb_run=None, track_time=True, score_method='argmax', task="classification",
            normalize_batch=True, loss_with_std=False, train_size=None, augment=0,
            augment_alpha=0,
        ):
        """Set up the weights, the projection model and bookkeeping state.

        Args:
            knots: per-device knot matrices (see class docstring).
            kernel_fn: callable ``kernel_fn(A, B)`` returning the
                ``(len(A), len(B))`` kernel matrix.
            gradient_correction_fn: correction callable; this class reads its
                ``lr``, ``batch_size``, ``Kmat_xs_xbatch_eval`` and
                ``time_track_dict`` attributes. NOTE(review): ``None`` substitutes
                the identity function, which has no ``batch_size``, so
                construction currently fails in that case — confirm whether
                ``None`` is meant to be supported.
            n_labels: number of output columns of ``weights``.
            device: list of devices; ``device[0]`` holds the weights.
                Defaults to ``[torch.device('cpu')]``.
            projection_device: device on which the Hilbert projection runs.
            projection_cutoff: initial cutoff passed to the projection.
            wandb_run: optional wandb run used for logging.
            track_time: log the timing dicts at the end of every epoch.
            score_method: ``'argmax'`` or ``'top5'`` (see ``score``).
            task: ``"classification"`` uses one-hot targets; any other value
                treats targets as regression values.
            normalize_batch: stored; not used by this class.
            loss_with_std: rescale ``loss`` by ``dataloader.std_y_tr ** 2``.
            train_size: stored; not used by this class.
            augment: enable mixup augmentation in ``fit_batch`` when truthy.
            augment_alpha: mixup parameter forwarded to ``mixup_data``.
        """
        super().__init__()
        if device is None:
            # Built per call: a list default would be shared by every instance.
            device = [torch.device('cpu')]

        self.task = task
        self.normalize_batch = normalize_batch
        self.augment = augment
        self.augment_alpha = augment_alpha

        self.tracktime = track_time
        self.correction_fn = Id if gradient_correction_fn is None else gradient_correction_fn
        self.lr = 0.1 if gradient_correction_fn is None else self.correction_fn.lr
        self.batch_size = self.correction_fn.batch_size
        self.kernel_fn = kernel_fn
        self.pinned_list = []  # tensors created via tensor(..., release=True)
        self.device = device
        self.n_labels = n_labels

        self.train_size = train_size

        self.score_method = score_method

        # 0 until the first weight update; afterwards fit_epoch sets it to 1 so
        # update_weights re-estimates the projection cutoff on every update.
        self.update_cutoff = 0

        self.knots = knots
        self.knots_inner = knots[-1].to(projection_device)
        self.projection_device = projection_device
        self.projection_cutoff = torch.tensor(projection_cutoff)
        n_knots = len(self.knots[0])
        self.weights = self.tensor(torch.zeros(n_knots, self.n_labels), release=True)
        # Weights before the latest update; read by update_projection_cutoff.
        self.weights_hist = self.weights.detach().clone()
        # Function value on the knots.
        self.prediction_knots = self.tensor(torch.zeros(n_knots, self.n_labels), release=True)

        self.projection_model = HilbertProjection(
            self.kernel_fn, self.knots[-1], n_labels,
            device=self.projection_device, wandb=wandb_run,
        )

        self.epoch = 0
        self.loss_with_std = loss_with_std
        self.wandb_run = wandb_run
        if wandb_run is not None:
            self.wandb_run.summary['effective batchsize'] = self.batch_size * len(self.device)
            self.wandb_run.summary['learning rate'] = self.lr
            self.wandb_run.summary['new top eigenvalue'] = self.correction_fn.tail_eig_x
            self.wandb_run.summary['original top eigenvalue'] = self.correction_fn.Lam_x[0]

        # One [value0, value1] pair per @calculate_time-decorated method; the
        # decorator (defined elsewhere) presumably accumulates timings here.
        self.time_track_dict = {'get_gradient': [0, 0], 'evaluate_corrected_gradient_at_knots': [0, 0],
                                'predict': [0, 0], 'fit_batch': [0, 0], 'fit_epoch': [0, 0],
                                'k_z_xbatch_eval': [0, 0], 'get_gz': [0, 0],
                                'update_projection_cutoff': [0, 0], 'update_weights': [0, 0]}

    def time_track_wandb(self, dict_time):
        """Log the entries of a timing dict to wandb (if any) and reset them.

        ``value[1]`` is logged as metric ``'avgTPC: <key>'`` against the current
        epoch and ``value[0]`` is stored as summary ``'CPE: <key>'``. Every entry
        is reset to ``[0, 0]`` even when no wandb run is attached.
        """
        run = self.wandb_run
        if run is not None:
            run.define_metric("epoch")
        for key, value in dict_time.items():
            if run is not None:
                metric = 'avgTPC: ' + key
                # Register the name that is actually logged so it is plotted
                # against "epoch".
                run.define_metric(metric, step_metric="epoch")
                run.log({metric: value[1], "epoch": self.epoch})
                run.summary['CPE: ' + key] = value[0]
            dict_time[key] = [0, 0]

    def one_hot(self, labels, ind):
        """One-hot encode ``labels`` into ``n_labels`` columns on ``device[ind]``."""
        return torch.nn.functional.one_hot(labels.long(), self.n_labels).to(self.device[ind])

    def _match_target_shape(self, y_batch):
        """Return regression targets shaped like the ``(batch, n_labels)`` outputs.

        A 1-D target of length ``batch`` would otherwise broadcast against a
        ``(batch, 1)`` prediction into a ``(batch, batch)`` matrix.
        """
        if self.n_labels == 1 and y_batch.dim() == 1:
            return y_batch.unsqueeze(1)
        return y_batch

    @calculate_time
    def k_z_xbatch_eval(self, X_batch, ind):
        """Kernel matrix between ``knots[ind]`` and the batch."""
        return self.kernel_fn(self.knots[ind], X_batch)

    @calculate_time
    def get_gradient(self, y_batch, k_z_xbatch, ind):
        """Residual of the current model on the batch: predictions minus targets.

        Classification targets are one-hot encoded first.
        """
        mult = k_z_xbatch.T @ self.weights.to(self.device[ind])
        if self.task == "classification":
            out = mult - self.one_hot(y_batch, ind)
        else:
            out = mult - self._match_target_shape(y_batch)
        return out

    @calculate_time
    def get_gz(self, grad, k_z_xbatch):
        """Evaluate the batch residual ``grad`` at the knots."""
        return k_z_xbatch @ grad

    @calculate_time
    def evaluate_corrected_gradient_at_knots(self, X_batch, grad, k_z_xbatch, ind):
        """Gradient at the knots minus the correction from ``correction_fn``."""
        Kmat_xs_xbatch = self.correction_fn.Kmat_xs_xbatch_eval(X_batch, ind)
        gtilde_z = self.correction_fn(grad, Kmat_xs_xbatch, ind)
        gz = self.get_gz(grad, k_z_xbatch)

        # Free the correction kernel matrix before returning.
        del Kmat_xs_xbatch
        torch.cuda.empty_cache()

        return gz - gtilde_z

    def tensor(self, data, dtype=None, release=False):
        """Copy ``data`` into a new tensor on ``device[0]``.

        With ``release=True`` the tensor is also appended to ``pinned_list``.
        """
        if torch.is_tensor(data):
            tensor = data.detach().clone().to(dtype=dtype, device=self.device[0])
        else:
            tensor = torch.tensor(data, dtype=dtype,
                                  requires_grad=False).to(self.device[0])
        if release:
            self.pinned_list.append(tensor)
        return tensor

    @calculate_time
    def predict(self, X, g):
        """Model output ``kernel_fn(X, knots[0]) @ weights`` computed on ``g``.

        ``g`` is handed directly to ``Tensor.to``; note that a bare int is
        interpreted by PyTorch as a CUDA device ordinal.
        """
        return self.kernel_fn(X, self.knots[0].to(g)) @ self.weights.to(g)

    @calculate_time
    def update_projection_cutoff(self):
        """Set the projection cutoff from the size of the last weight change.

        On a random subset of ``max(n_knots // 100, 1000)`` knots (capped at
        ``n_knots``) it computes ``sum(dA^T K dA) / n_labels`` and uses 1/100
        of that as the new cutoff. ``dA`` is the difference between the
        previous and current weights, and ``K`` is the kernel on the subset.
        """
        n_knots = self.knots[0].size()[0]
        permutation = torch.randperm(n_knots)
        knots_randomind = permutation[0:max(n_knots // 100, 1000)]

        random_knots = self.knots[0][knots_randomind]
        alpha_dif = self.weights_hist[knots_randomind] - self.weights[knots_randomind]

        kmat_tmp = self.kernel_fn(random_knots, random_knots)

        alpha_dist = torch.sum(alpha_dif.T @ kmat_tmp @ alpha_dif) / self.n_labels

        self.projection_cutoff = alpha_dist / 100

    @calculate_time
    def fit_batch(self, X_batch, y_batch, ind):
        """Accumulate the scaled corrected gradient of one batch.

        Adds ``lr / (batch_size * len(device))`` times the corrected gradient
        at the knots to ``self.corrected_gz_scaled``. That attribute is reset
        to 0 by ``fit_epoch`` before each batch. With ``augment`` enabled the
        batch is mixed with ``mixup_data`` and the residual is the
        ``lam``-weighted combination for both label sets.

        Returns:
            ``(grad, corrected_gz_scaled)``.
        """
        if self.augment:
            print(f'alpha augment = {self.augment_alpha}')
            X_batch, y_a, y_b, lam = mixup_data(X_batch, y_batch,
                                                self.augment_alpha, self.device[ind])
        k_z_xbatch = self.k_z_xbatch_eval(X_batch, ind)

        if self.augment:
            grad = (lam * self.get_gradient(y_a, k_z_xbatch, ind)
                    + (1 - lam) * self.get_gradient(y_b, k_z_xbatch, ind))
        else:
            grad = self.get_gradient(y_batch, k_z_xbatch, ind)

        corrected_gz = self.evaluate_corrected_gradient_at_knots(
            X_batch, grad, k_z_xbatch, ind)
        del k_z_xbatch
        torch.cuda.empty_cache()
        self.corrected_gz_scaled += (self.lr / (self.batch_size * len(self.device))) * corrected_gz.to(self.device[0])
        return grad, self.corrected_gz_scaled

    @calculate_time
    def update_weights(self):
        """Project the updated knot values back to weights via ``HilbertProjection``.

        The target is ``prediction_knots - corrected_gz_scaled``. The previous
        weights are kept in ``weights_hist`` before being overwritten.
        """
        if self.update_cutoff:
            self.update_projection_cutoff()

        cut_off_tmp = self.projection_cutoff.detach().clone().to(self.projection_device)
        gz_projection = (self.prediction_knots - self.corrected_gz_scaled).to(self.projection_device)

        weights_tmp, prediction_knots_tmp = self.projection_model.fit_hilbert_projection(
            self.knots_inner,
            gz_projection,
            cutoff=cut_off_tmp, mem_gb=30,
            return_log=False
        )

        self.weights_hist = self.weights.detach().clone()
        self.weights = weights_tmp.to(self.device[0])
        self.prediction_knots = prediction_knots_tmp.to(self.device[0])

    @calculate_time
    def fit_epoch(self, train_loader, valid_loader, epochs):
        """Run one pass over ``train_loader``, updating the weights after each batch.

        Validation (``score`` and, with wandb, ``loss``) runs each time another
        ``n_knots`` training samples have been seen. It is skipped when
        ``valid_loader`` is None.
        """
        n_knots = self.knots[0].shape[0]
        for batch_num, (X_batch, y_batch) in enumerate(train_loader):
            self.data_count += X_batch.shape[0]
            self.corrected_gz_scaled = 0
            X_batch = X_batch.view(X_batch.shape[0], -1).to(self.device[0])
            y_batch = y_batch.to(self.device[0])

            self.fit_batch(X_batch, y_batch, 0)

            self.update_weights()

            print(f'Fit : batch {batch_num+1} of {len(train_loader) }')

            self.update_cutoff = 1

            # torch.cuda.synchronize rejects non-CUDA devices (the default
            # projection_device is CPU), so only synchronize on a GPU.
            if torch.device(self.projection_device).type == 'cuda':
                torch.cuda.synchronize(device=self.projection_device)
            del X_batch, y_batch

            print(f'data_count:{self.data_count}')

            if valid_loader is None or len(self.acc_valid) >= self.data_count // n_knots:
                continue

            print(f'self.data_count:{self.data_count}')
            # Number of knot-sized "epochs" completed so far.
            t = self.data_count // n_knots
            self.acc_valid.append(self.score(valid_loader, method=self.score_method))
            if self.wandb_run is not None:
                self.valid_loss = self.loss(valid_loader)
                self.wandb_run.define_metric("epoch")
                if self.task == "classification":
                    self.wandb_run.define_metric("validation_accuracy", step_metric="epoch")
                self.wandb_run.define_metric("validation_loss_std-rmse", step_metric="epoch")
                self.wandb_run.define_metric("weight_norm", step_metric="epoch")
                self.wandb_run.define_metric("projection_cutoff", step_metric="epoch")

                # acc_valid[-1] is the score computed just above.
                log_dict = {'validation_accuracy': self.acc_valid[-1],
                            'validation_loss_std-rmse': self.valid_loss,
                            'weight_norm': torch.norm(self.weights),
                            'projection_cutoff': self.projection_cutoff,
                            "epoch": t
                            }
                self.wandb_run.log(log_dict)

                print(f'Fit: end of epoch {t} of {epochs*train_loader.dataset.x.shape[0]//n_knots}'
                      f' valid acc = {self.acc_valid[-1]}')

    def score(self, dataloader, method='argmax'):
        """Fraction of correctly classified samples in ``dataloader``.

        Args:
            dataloader: iterable of ``(X_batch, y_batch)`` pairs.
            method: ``'argmax'`` compares ``Yaccu`` of predictions and targets;
                ``'top5'`` counts a hit when the true label is among
                ``Yaccu(yhat, 'top5')``.

        Raises:
            ValueError: for an unknown ``method`` or an empty ``dataloader``.
        """
        if method not in ('argmax', 'top5'):
            raise ValueError(f"unknown score method {method!r}; expected 'argmax' or 'top5'")
        accu = 0
        data_size = 0
        for X_batch, y_batch in dataloader:
            y_batch = y_batch.to(self.device[0])
            X_batch = X_batch.view(X_batch.shape[0], -1).to(self.device[0])
            if self.task == "classification":
                y_batch = self.one_hot(self.tensor(y_batch, dtype=torch.long), 0)

            # Pass the device itself: a bare 0 would be read as cuda:0.
            yhat = self.predict(X_batch, self.device[0])
            data_size += yhat.shape[0]
            if method == 'argmax':
                accu += sum(Yaccu(yhat) == Yaccu(y_batch))
            else:
                top_pred = Yaccu(yhat, method)
                true_y = Yaccu(y_batch, method='argmax')
                for truth, candidates in zip(true_y, top_pred):
                    if truth in candidates:
                        accu += 1
            del X_batch, y_batch
            torch.cuda.empty_cache()
        if data_size == 0:
            raise ValueError('cannot score an empty dataloader')
        if torch.is_tensor(accu):
            accu = accu.cpu().detach().item()
        return accu / data_size

    def loss(self, dataloader):
        """Mean per-sample summed squared error of the model on ``dataloader``.

        With ``loss_with_std`` the result is multiplied by
        ``dataloader.std_y_tr ** 2``.

        Raises:
            ValueError: if ``dataloader`` yields no samples.
        """
        loss = 0.0
        loss_mse = torch.nn.MSELoss(reduction='sum')
        data_size = 0
        for X_batch, y_batch in dataloader:
            y_batch = y_batch.to(self.device[0])
            X_batch = X_batch.view(X_batch.shape[0], -1).to(self.device[0])

            if self.task == "classification":
                y_batch = self.one_hot(self.tensor(y_batch, dtype=torch.long), 0)
            else:
                y_batch = self._match_target_shape(y_batch)
            yhat = self.predict(X_batch, self.device[0])
            loss += loss_mse(yhat, y_batch).cpu().detach().item()
            data_size += yhat.shape[0]
            del X_batch, y_batch, yhat
            torch.cuda.empty_cache()

        if data_size == 0:
            raise ValueError('cannot compute the loss of an empty dataloader')
        loss /= data_size
        if self.loss_with_std:
            loss *= dataloader.std_y_tr.cpu().detach().item() ** 2
        return loss

    def fit(self, train_loader, epochs=1,
            valid_loader=None):
        """Train for ``epochs`` passes over ``train_loader``.

        Returns:
            ``(weights, acc_valid, valid_loss)``. ``valid_loss`` is None unless a
            validation step ran with a wandb run attached.
        """
        self.acc_valid = []
        self.data_count = 0
        # Only assigned by fit_epoch's wandb branch; initialise so the return
        # below never hits a missing attribute.
        self.valid_loss = None
        for t in range(epochs):
            self.epoch = t
            self.fit_epoch(train_loader, valid_loader, epochs)

            if self.tracktime:
                self.time_track_wandb(self.time_track_dict)
                self.time_track_wandb(self.correction_fn.time_track_dict)
                self.time_track_wandb(self.projection_model.time_track_dict)

        return self.weights, self.acc_valid, self.valid_loss


if __name__ == "__main__":
    # Smoke run on one CIFAR-10 batch. The relative imports below require
    # running this file as a module of its package (python -m ...).
    # torch, numpy, HilbertProjection and GradientCorrection are already
    # imported at module level.
    from .utils.cifar10 import Cifar10DataLoader
    from .utils.kernel import gaussian

    def kernel_fn(x, y):
        """Gaussian kernel with bandwidth 1."""
        return gaussian(x, y, bandwidth=1)

    torch.manual_seed(120)
    np.random.seed(120)
    cifar10_validloader = Cifar10DataLoader(
        parts=1, train=True, validation=True, batch_size=256, shuffle=True,
    )
    # NOTE(review): this loader yields (ids, X, y) triples, while
    # KernelModel.fit_epoch unpacks (X, y) pairs — confirm the batch format.
    ids, X, y = next(iter(cifar10_validloader))
    n_labels = len(y.unique())
    randomind = np.random.choice(range(len(y)), size=100, replace=False)
    knots = X[randomind]
    nystrom_ind = np.random.choice(range(len(y)), size=100, replace=False)
    # NOTE(review): confirm whether GradientCorrection expects a single knot
    # matrix or a per-device list like KernelModel does.
    Px = GradientCorrection(
        kernel_fn=kernel_fn,
        nystrom_samples=X[nystrom_ind], level=2,
        knots=knots, device="cpu",
    )
    # KernelModel indexes knots per device (knots[0], knots[-1]), so wrap the
    # single knot matrix in a one-element list.
    model = KernelModel(
        [knots], kernel_fn, n_labels=n_labels,
        gradient_correction_fn=Px,
    )
    # fit returns (weights, acc_valid, valid_loss).
    alpha, acc_valid, valid_loss = model.fit(cifar10_validloader, epochs=2)
