from datetime import datetime
from .utils import Yaccu,calculate_time
from .gradient_correction import GradientCorrection
from .utils.datasets import taxidataset_all,Cifar5mDataset,dataset_custom
import torch
import copy
import numpy as np
from torch.nn.functional import one_hot
import ipdb
import os
import pickle
from datetime import datetime
from .hilbert_projection import HilbertProjection
# from .hilbert_projection import KernelRegressionModel
import time
from timeit import default_timer as timer
import concurrent.futures
# Wall-clock time at module import; used as the prefix of checkpoint file names
# written by KernelModel.fit.
# NOTE(review): the ':' in '%H:%M' is not a legal file-name character on Windows.
time_stamp = datetime.today().strftime('%Y-%m-%d-%H:%M')


def Id(x):
    """Return ``x`` unchanged; the fallback when no gradient correction is given."""
    return x


# Optional sub-directory inserted between $RESULTS_DIRNAME and the date folder
# when checkpoints are saved.
RESULTS_LOCAL_DIRNAME = ''

class KernelModel:
    """Kernel model ``f(x) = kernel_fn(x, knots[0]) @ weights`` trained with corrected, projected gradient steps.

    One training step (see ``fit_epoch``):

    1. Split a global batch across ``device`` (``batch_size`` rows per device).
    2. In one worker thread per device, compute the squared-loss residual on its
       chunk and apply ``gradient_correction_fn`` to it.
    3. Sum the per-device contributions on ``device[0]``.
    4. Project that sum onto ``knots[-1]`` with ``HilbertProjection``.
    5. Move ``weights`` by ``lr / (batch_size * len(device))`` times the projection.

    Args:
        knots: list of knot tensors. ``knots[i]`` is used by the worker for
            ``device[i]``. ``knots[0]`` indexes the rows of ``weights``.
            ``knots[-1]`` is copied to the projection device. Presumably the
            entries are copies of the same knots on different devices; confirm
            with callers.
        kernel_fn: ``kernel_fn(A, B)`` returning the kernel matrix between the
            rows of ``A`` and the rows of ``B``.
        gradient_correction_fn: correction object. It must provide ``lr``,
            ``batch_size``, ``Kmat_xs_xbatch_eval``, ``nystrom_samples`` and,
            when ``wandb_run`` is set, ``tail_eig_x``/``Lam_x``. It is called as
            ``fn(grad, Kmat_xs_xbatch, ind)``. NOTE(review): ``None`` is
            accepted, but the identity fallback has no ``batch_size``, so
            construction fails without a correction object.
        n_labels: number of output columns of ``weights``.
        device: a device or a list of devices for the training workers
            (default: CPU). ``device[0]`` holds ``weights``.
        projection_device: a device or a list of devices (default: CPU). Only
            element 0 is used, to run the Hilbert projection.
        projection_cutoff: initial ``self.projection_cutoff``. In this class it
            is only logged to W&B.
        wandb_run: optional W&B run for summaries and validation logs.
        track_time: stored as ``self.tracktime``.
        score_method: ``'argmax'`` or ``'top5'``, used for validation accuracy.
        task: ``"classification"`` (integer labels, one-hot encoded here).
            Any other value means regression on the raw targets.
        normalize_batch: if true, inputs are cast to float and divided by 255.
        loss_with_std: if true, ``loss`` is multiplied by ``dataloader.std_y_tr ** 2``.
        train_size: recorded in checkpoint file names.
        knots_y: stored and written to checkpoints.
        checkpoint: path to a pickle with an ``'alpha'`` entry to resume from.
        dataset_name, kernel_type: sub-directories for checkpoint files.
    """

    def __init__(
            self, knots, kernel_fn, gradient_correction_fn=None, n_labels=1,
            device=None,
            projection_device=None, projection_cutoff=1 * 10 ** -5,
            wandb_run=None, track_time=True, score_method='argmax', task="classification",
            normalize_batch=True, loss_with_std=False, train_size=None, knots_y=None,
            checkpoint=None, dataset_name=None, kernel_type=None
        ):
        super().__init__()

        self.kernel_type = kernel_type
        self.dataset_name = dataset_name
        self.load_checkpoint = checkpoint

        self.task = task
        self.normalize_batch = normalize_batch

        self.tracktime = track_time
        self.correction_fn = Id if gradient_correction_fn is None else gradient_correction_fn
        self.lr = 0.1 if gradient_correction_fn is None else self.correction_fn.lr
        # NOTE(review): raises AttributeError when gradient_correction_fn is None
        # because the identity fallback has no batch_size.
        self.batch_size = self.correction_fn.batch_size
        self.kernel_fn = kernel_fn
        self.pinned_list = []
        # Both device arguments are indexed with [0] below. Normalising them here
        # also fixes the old scalar projection_device default, which was not
        # indexable.
        self.device = self._as_device_list(device)
        self.n_labels = n_labels

        self.train_size = train_size

        self.score_method = score_method

        self.update_cutoff = 0

        self.knots = knots
        self.knots_y = knots_y
        self.projection_device = self._as_device_list(projection_device)
        self.knots_inner = knots[-1].to(self.projection_device[0])
        self.projection_cutoff = torch.as_tensor(projection_cutoff)

        if self.load_checkpoint is not None:
            print("alpha loads from checkpoint...")
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # checkpoints from trusted locations.
            with open(self.load_checkpoint, 'rb') as fp:
                checkpoint = pickle.load(fp)
            # The checkpoint may come from a run on other devices; the weights
            # must live on device[0] like in update_weights.
            self.weights = checkpoint['alpha'].to(self.device[0])
            del checkpoint
        else:
            self.weights = self.tensor(torch.zeros(
                len(self.knots[0]), self.n_labels), release=True)

        self.theta2_hist = self.weights.detach().clone()

        self.weights_hist = self.weights.detach().clone()

        self.projection_model = HilbertProjection(
            self.kernel_fn, self.knots_inner, n_labels,
            device=self.projection_device[0], wandb=wandb_run)

        print("gettign prediction knots")
        if self.load_checkpoint is not None:
            self.prediction_knots, _ = self.projection_model.get_predictions_alpha(self.weights)
            self.prediction_knots = self.prediction_knots.to(self.device[0])
        else:
            self.prediction_knots = self.tensor(torch.zeros(
                len(self.knots[0]), self.n_labels), release=True)

        self.epoch = 0
        self.loss_with_std = loss_with_std
        self.wandb_run = wandb_run
        if wandb_run is not None:
            self.wandb_run.summary['learning rate'] = self.lr
            self.wandb_run.summary['new top eigenvalue'] = self.correction_fn.tail_eig_x
            self.wandb_run.summary['original top eigenvalue'] = self.correction_fn.Lam_x[0]

        # Timing buckets keyed by method name, one per @calculate_time method.
        self.time_track_dict = {'get_gradient': [0, 0], 'evaluate_corrected_gradient_at_knots': [0, 0],
                                'predict': [0, 0], 'fit_batch': [0, 0], 'fit_epoch': [0, 0],
                                'k_z_xbatch_eval': [0, 0], 'get_gz': [0, 0],
                                'update_projection_cutoff': [0, 0], 'update_weights': [0, 0]}

    @staticmethod
    def _as_device_list(devices):
        """Return ``devices`` as a list: ``None`` -> ``[cpu]``, one device -> ``[device]``."""
        if devices is None:
            return [torch.device('cpu')]
        if isinstance(devices, (torch.device, str)):
            return [torch.device(devices)]
        return list(devices)

    @staticmethod
    def _synchronize(device):
        """Wait for queued CUDA work on ``device``; no-op for non-CUDA devices."""
        # torch.cuda.synchronize rejects non-CUDA devices, so CPU runs must skip it.
        if torch.device(device).type == 'cuda':
            torch.cuda.synchronize(device=device)

    def time_track_wandb(self, dict_time):
        """Log each ``[value0, value1]`` timing entry to W&B, then reset it to ``[0, 0]``.

        ``value[1]`` is logged as ``'avgTPC: <key>'`` against the current epoch,
        and ``value[0]`` is stored in the run summary as ``'CPE: <key>'``.
        Requires ``self.wandb_run`` to be set.
        """
        for key, value in dict_time.items():
            self.wandb_run.define_metric("epoch")
            self.wandb_run.define_metric(key, step_metric="epoch")
            log_dict = {'avgTPC: ' + key: value[1], "epoch": self.epoch}
            self.wandb_run.log(log_dict)
            self.wandb_run.summary['CPE: ' + key] = value[0]
            dict_time[key] = [0, 0]

    def one_hot(self, labels, ind):
        """One-hot encode ``labels`` into ``n_labels`` columns on ``device[ind]``."""
        return torch.nn.functional.one_hot(labels.long(), self.n_labels).to(self.device[ind])

    @calculate_time
    def k_z_xbatch_eval(self, X_batch, ind):
        """Kernel matrix between ``knots[ind]`` and ``X_batch``."""
        return self.kernel_fn(self.knots[ind], X_batch)

    @calculate_time
    def get_gradient(self, y_batch, k_z_xbatch, ind):
        """Residual of the current model on a batch: predictions minus targets.

        Predictions are ``k_z_xbatch.T @ weights``. For classification the
        targets are the one-hot encoded ``y_batch``.
        """
        mult = k_z_xbatch.T @ self.weights.to(self.device[ind])
        if self.task == "classification":
            out = mult - self.one_hot(y_batch, ind)
        else:
            out = mult - y_batch
        return out

    @calculate_time
    def get_gz(self, grad, k_z_xbatch):
        """Map a batch residual back onto the knots: ``k_z_xbatch @ grad``."""
        return k_z_xbatch @ grad

    @calculate_time
    def evaluate_corrected_gradient_at_knots(self, X_batch, grad, k_z_xbatch, ind):
        """Gradient at the knots minus the correction term from ``correction_fn``."""
        Kmat_xs_xbatch = self.correction_fn.Kmat_xs_xbatch_eval(X_batch, ind)
        gtilde_z = self.correction_fn(grad, Kmat_xs_xbatch, ind)
        gz = self.get_gz(grad, k_z_xbatch)

        del Kmat_xs_xbatch
        torch.cuda.empty_cache()

        return gz - gtilde_z

    def tensor(self, data, dtype=None, release=False):
        """Copy ``data`` to ``device[0]`` as a tensor.

        When ``release`` is true, a reference is also kept in ``self.pinned_list``.
        """
        if torch.is_tensor(data):
            tensor = data.detach().clone().to(dtype=dtype, device=self.device[0])
        else:
            tensor = torch.tensor(data, dtype=dtype,
                                  requires_grad=False).to(self.device[0])
        if release:
            self.pinned_list.append(tensor)
        return tensor

    @calculate_time
    def predict(self, X):
        """Model output ``kernel_fn(X, knots[0]) @ weights``."""
        return self.kernel_fn(X, self.knots[0]) @ self.weights

    @calculate_time
    def update_projection_cutoff(self):
        """Set ``projection_cutoff`` from the change ``theta2_hist - theta2`` on random knots.

        It samples up to ``max(n_knots // 100, 1000)`` knots and forms the
        difference ``d`` on them. The cutoff becomes
        ``sum(d.T @ K @ d) / n_labels / 10``. Requires ``self.theta2``, which
        ``update_weights`` sets. This method is not called elsewhere in this
        class, and ``theta2_hist`` is only set in ``__init__``.
        """
        permutation = torch.randperm(self.knots[0].size()[0])
        knots_randomind = permutation[0:max(self.knots[0].shape[0] // 100, 1000)]

        random_knots = self.knots[0][knots_randomind]

        alpha_dif = self.theta2_hist[knots_randomind] - self.theta2[knots_randomind]

        kmat_tmp = self.kernel_fn(random_knots, random_knots)

        alpha_dist = torch.sum(alpha_dif.T @ kmat_tmp @ alpha_dif) / self.n_labels

        self.projection_cutoff = alpha_dist / 10

    @calculate_time
    def fit_batch(self, X_batch, y_batch, ind, accumulate=True):
        """Compute one device's corrected gradient contribution for its chunk.

        Args:
            X_batch, y_batch: the chunk assigned to ``device[ind]``.
            ind: device / knot index.
            accumulate: if true (the original behaviour), add the contribution
                to ``self.corrected_gz_scaled``. ``fit_epoch`` passes ``False``
                from its worker threads and sums the returned values itself,
                because concurrent ``+=`` on shared state can lose updates.

        Returns:
            The contribution, moved to ``device[0]``.
        """
        k_z_xbatch = self.k_z_xbatch_eval(X_batch, ind)
        grad = self.get_gradient(y_batch, k_z_xbatch, ind)
        corrected_gz = self.evaluate_corrected_gradient_at_knots(
            X_batch, grad, k_z_xbatch, ind)
        del k_z_xbatch
        torch.cuda.empty_cache()

        contribution = corrected_gz.to(self.device[0])
        if accumulate:
            self.corrected_gz_scaled += contribution
        return contribution

    @calculate_time
    def update_weights(self):
        """Project the summed corrected gradient onto the inner knots and step ``weights``.

        The projection cutoff decays as ``1 / (1 + totalsteps ** 1.5)`` and is
        floored at ``1e-6``. ``mem_gb=33`` is presumably a memory budget for the
        projection solver; confirm with ``HilbertProjection``.
        """
        cut_off_tmp = max(1 / (1 + self.totalsteps ** 1.5), 1e-6)

        gz_projection = self.corrected_gz_scaled.to(self.projection_device[0])

        self.theta2, _ = self.projection_model.fit_hilbert_projection(
            self.knots_inner,
            gz_projection,
            cutoff=cut_off_tmp, mem_gb=33,
            return_log=False
        )
        new_weights = self.weights - (self.lr / (self.batch_size * len(self.device))) * self.theta2.to(self.device[0])

        self.weights_hist = self.weights.detach().clone()
        self.weights = new_weights.to(self.device[0])

    def _split_batch(self, X_batch, y_batch, loader_ind):
        """Slice a global batch into per-device chunks of ``batch_size`` rows.

        Chunk ``ind`` is moved to ``device[ind]`` unless ``ind == loader_ind``.
        That chunk is left where the loader keeps it; presumably loader ``i``
        already lives on device ``i``, so confirm with the data setup.
        """
        bs = int(self.batch_size)
        X_parts, y_parts = [], []
        for ind, g in enumerate(self.device):
            rows = slice(ind * bs, (ind + 1) * bs)
            X_part, y_part = X_batch[rows], y_batch[rows]
            if loader_ind != ind:
                X_part, y_part = X_part.to(g), y_part.to(g)
            if self.normalize_batch:
                X_part = X_part.float() / 255
            X_parts.append(X_part)
            y_parts.append(y_part)
        return X_parts, y_parts

    def _validate(self, valid_loader, batch_num, checkpoint):
        """Record validation accuracy and loss, print them, and log to W&B if configured."""
        if valid_loader is None:
            return
        self.acc_valid.append(self.score(valid_loader, method=self.score_method))
        self.valid_loss.append(self.loss(valid_loader))
        if self.task == "classification":
            print(f'Fit: step {batch_num}  valid acc = {self.acc_valid[-1]}')
        else:
            print(f'Fit: step {batch_num}  validation_std_mse = {np.sqrt(self.valid_loss[-1])}')

        if self.wandb_run is None:
            return
        step_key = f'{checkpoint}_step'
        # The step metric must name the same key that is logged below.
        self.wandb_run.define_metric("validation_accuracy", step_metric=step_key)
        log_dict = {'validation_accuracy': self.acc_valid[-1],
                    'validation_loss': self.valid_loss[-1],
                    'validation_std_mse': np.sqrt(self.valid_loss[-1]),
                    'weight_norm': torch.norm(self.weights),
                    'projection_cutoff': self.projection_cutoff,
                    step_key: len(self.acc_valid)}
        self.wandb_run.log(log_dict)

    @calculate_time
    def fit_epoch(self, train_loaders, valid_loader, checkpoint=50):
        """Run one pass over every loader in ``train_loaders``, then validate once.

        Each step takes ``len(device) * batch_size`` shuffled rows and splits
        them across devices. The chunks are processed in parallel threads,
        their contributions are summed on ``device[0]``, and ``update_weights``
        is called. Worker exceptions are re-raised here.

        Args:
            train_loaders: objects exposing tensors ``X`` and ``y``.
            valid_loader: validation data, or ``None`` to skip validation.
            checkpoint: only names the W&B step metric (``f'{checkpoint}_step'``).
        """
        train_loaders = list(train_loaders)
        n_devices = len(self.device)
        step = n_devices * int(self.batch_size)
        total_batches = sum((loader.X.size()[0] + step - 1) // step for loader in train_loaders)

        batch_num = 0
        with concurrent.futures.ThreadPoolExecutor(max_workers=n_devices) as executor:
            for loader_ind, train_loader in enumerate(train_loaders):
                n_rows = train_loader.X.size()[0]
                permutation = torch.randperm(n_rows)
                for start in range(0, n_rows, step):
                    batch_ids = permutation[start:start + step]
                    X_batch = train_loader.X[batch_ids]
                    y_batch = train_loader.y[batch_ids]
                    X_parts, y_parts = self._split_batch(X_batch, y_batch, loader_ind)

                    # Empty chunks (the tail of the last batch) contribute nothing.
                    futures = [executor.submit(self.fit_batch, X_part, y_part, ind, False)
                               for ind, (X_part, y_part) in enumerate(zip(X_parts, y_parts))
                               if X_part.shape[0] > 0]
                    # Summing in this thread avoids the shared-+= race, and
                    # .result() re-raises any exception from a worker.
                    self.corrected_gz_scaled = sum((f.result() for f in futures), 0)

                    for g in self.device:
                        self._synchronize(g)
                    self.update_weights()

                    self.totalsteps += n_devices

                    print(f'Fit : batch {batch_num+1} of {total_batches}')

                    batch_num += 1
                    self.update_cutoff = 1

                    self._synchronize(self.projection_device[0])
                    del batch_ids, X_batch, y_batch, X_parts, y_parts

        # Validate after the last batch of the epoch. The old
        # `batch_num == train_size // step` test missed it whenever train_size
        # was a multiple of step.
        if batch_num > 0:
            self._validate(valid_loader, batch_num - 1, checkpoint)

    def score(self, dataloader, method='argmax'):
        """Accuracy of the model on ``dataloader``.

        Args:
            dataloader: object exposing tensors ``X`` and ``y``.
            method: ``'argmax'`` compares ``Yaccu`` of predictions and targets.
                ``'top5'`` counts a hit when the target's argmax appears in
                ``Yaccu(pred, 'top5')``.

        Returns:
            Fraction of evaluated rows that are correct. At most 10 001 batches
            are evaluated, and the denominator counts only those rows. Returns
            NaN for an empty loader.

        Raises:
            ValueError: for an unknown ``method``.
        """
        if method not in ('argmax', 'top5'):
            raise ValueError(f"unknown score method: {method!r}")

        n_rows = dataloader.X.size()[0]
        bs = int(self.batch_size)
        correct = 0
        n_seen = 0
        permutation = torch.randperm(n_rows)
        for batch_num, start in enumerate(range(0, n_rows, bs)):
            if batch_num > 10_000:
                break
            batch_ids = permutation[start:start + bs]
            # Move to device[0] first so the normalised path matches the raw one.
            X_batch = dataloader.X[batch_ids].to(self.device[0])
            if self.normalize_batch:
                X_batch = X_batch.float() / 255
            y_batch = dataloader.y[batch_ids].to(self.device[0])
            if self.task == "classification":
                y_batch = self.one_hot(self.tensor(y_batch, dtype=torch.long), 0)

            yhat_test = self.predict(X_batch)
            if method == 'argmax':
                correct += int((Yaccu(yhat_test) == Yaccu(y_batch)).sum().item())
            else:
                top_k = Yaccu(yhat_test, method)
                true_y = Yaccu(y_batch, method='argmax')
                # A row is a hit if its true label occurs anywhere in its top-k row.
                correct += int((top_k == true_y.reshape(-1, 1)).any(axis=1).sum().item())
            n_seen += batch_ids.numel()
            del X_batch, y_batch, batch_ids
            torch.cuda.empty_cache()
        return correct / n_seen if n_seen else float('nan')

    def loss(self, dataloader):
        """Mean squared error per row (summed over outputs) on ``dataloader``.

        Classification targets are one-hot encoded. If ``loss_with_std`` is set,
        the result is multiplied by ``dataloader.std_y_tr ** 2``. Returns NaN
        for an empty loader.
        """
        loss = 0.0
        loss_mse = torch.nn.MSELoss(reduction='sum')
        data_size = 0
        bs = int(self.batch_size)

        n_rows = dataloader.X.size()[0]
        permutation = torch.randperm(n_rows)
        for start in range(0, n_rows, bs):
            batch_ids = permutation[start:start + bs]
            X_batch = dataloader.X[batch_ids].to(self.device[0])
            if self.normalize_batch:
                X_batch = X_batch.float() / 255
            y_batch = dataloader.y[batch_ids].to(self.device[0])
            if self.task == "classification":
                y_batch = self.one_hot(self.tensor(y_batch, dtype=torch.long), 0)

            yhat_test = self.predict(X_batch)

            loss += loss_mse(yhat_test, y_batch).cpu().detach().item()
            # shape[0] also works for 1-D regression targets.
            data_size += y_batch.shape[0]
            del X_batch, y_batch, batch_ids
            torch.cuda.empty_cache()

        if data_size == 0:
            return float('nan')
        loss /= data_size
        if self.loss_with_std:
            loss *= dataloader.std_y_tr.cpu().detach().item() ** 2
        return loss

    def _save_checkpoint(self):
        """Pickle weights, validation history, Nyström samples and knots.

        The file goes under ``$RESULTS_DIRNAME/RESULTS_LOCAL_DIRNAME/<date>/<dataset_name>/<kernel_type>``.
        ``None`` path components are skipped.
        """
        result_dict = {'alpha': self.weights, 'acc_valid_ep3': self.acc_valid,
                       'nystrom_samples': self.correction_fn.nystrom_samples,
                       'knots_x': self.knots, "knots_y": self.knots_y}
        day_stamp = datetime.today().strftime('%m-%d-%y')

        parts = [os.environ['RESULTS_DIRNAME'], RESULTS_LOCAL_DIRNAME,
                 day_stamp, self.dataset_name, self.kernel_type]
        dirname = os.path.join(*(part for part in parts if part is not None))
        os.makedirs(dirname, exist_ok=True)
        filename = f'{time_stamp}-n={self.train_size}-p={self.knots[0].shape[0]}-checkpoint={self.epoch}'
        filename = os.path.join(dirname, filename)
        with open(f'{filename}.pickle', 'wb') as f_pkl:
            pickle.dump(result_dict, f_pkl)

    def fit(self, train_loaders, epochs=1,
            valid_loader=None):
        """Train for ``epochs`` epochs, checkpointing every 5th epoch.

        Returns:
            ``(weights, acc_valid)``, where ``acc_valid`` is the list of
            validation accuracies (initial value plus one per epoch). It is
            empty when ``valid_loader`` is ``None``.
        """
        self.totalsteps = 0
        self.time_start = time.time()
        self.acc_valid = []
        self.valid_loss = []
        if valid_loader is not None:
            self.acc_valid.append(self.score(valid_loader, method=self.score_method))
            self.valid_loss.append(self.loss(valid_loader))
            print(f'Starting accuracy  valid acc = {self.acc_valid[-1]}')

        for t in range(epochs):
            self.epoch = t
            print(f'Fit: start of epoch {t+1} of {epochs}')
            self.fit_epoch(train_loaders, valid_loader)
            if (self.epoch + 1) % 5 == 0:
                self._save_checkpoint()

        return self.weights, self.acc_valid


if __name__ == "__main__":
    # Smoke test on random data. The relative imports require running this file
    # as a module of its package (python -m <package>.<module>).
    import tempfile

    from .utils.kernel import gaussian

    def kernel_fn(x, y):
        return gaussian(x, y, bandwidth=10)

    torch.manual_seed(120)
    np.random.seed(120)

    n_classes = 10
    devices = [torch.device('cpu')]
    X = torch.randn(1000, 100)
    y = torch.randint(0, n_classes, (1000,))
    X_test = torch.randn(200, 100).to(devices[0])
    y_test = torch.randint(0, n_classes, (200,)).to(devices[0])
    randomind = np.random.choice(range(len(y)), size=50, replace=False)
    knots_x = X[randomind].to(devices[0])
    # Give num_classes explicitly. Inferring it from 50 sampled labels can miss
    # a class and leave n_labels too small for the training labels.
    knots_y = one_hot(y[randomind], num_classes=n_classes).to(devices[0])
    nystrom_ind = np.random.choice(range(len(y)), size=100, replace=False)

    # fit() pickles a checkpoint every 5 epochs under $RESULTS_DIRNAME. Default
    # it to a throw-away directory so the smoke test runs as-is.
    os.environ.setdefault('RESULTS_DIRNAME', tempfile.mkdtemp())

    gradient_correction_fn = GradientCorrection(
        kernel_fn=kernel_fn,
        nystrom_samples=X[nystrom_ind], level=2,
        knots=[knots_x], device=devices, batch_size=10,
        learning_rate_prefactor=0.99, wandb_run=None
    )
    model = KernelModel(
        kernel_fn=kernel_fn, knots=[knots_x], knots_y=knots_y,
        gradient_correction_fn=gradient_correction_fn,
        projection_cutoff=0.1, projection_device=devices, device=devices,
        n_labels=knots_y.shape[1],
        wandb_run=None, track_time=True, normalize_batch=False,
        loss_with_std=False,
        train_size=X.shape[0],
        dataset_name='random', kernel_type='gaussian'
    )
    traindatasets = []
    testdatasets = dataset_custom(X_test, y_test, knots_x=knots_x, knots_y=knots_y)
    traindatasets.append(dataset_custom(X.to(devices[0]), y.to(devices[0])))

    # fit() returns (weights, validation accuracies): two values, not three.
    alpha, acc_valid_ep3 = model.fit(
        traindatasets, epochs=10, valid_loader=testdatasets)


