import torch
from torch.optim import SGD, Adam
from collections import Counter
from utils.utils import *
from tqdm import tqdm

class Pipeline:
    """Drive training, evaluation, checkpointing and resumption over a sequence of tasks.

    The pipeline builds per-task train/test (and optional calibration) loaders,
    owns the metric trackers, and delegates learning/evaluation to ``model``
    through the hooks called below (``observe``, ``evaluate``, ``acc``,
    ``reset_eval``, ``compute_stats``, and optionally ``preprocess_task``,
    ``end_task``, ``save``).
    """

    # Per-class training-sample counts assumed by ``single_head`` when turning
    # the number of seen classes into a number of optimisation steps per epoch.
    _TRAIN_SAMPLES_PER_CLASS = {
        'cifar10': 5000,
        'cifar100': 500,
        'timgnet': 500,
        'imgnet380': 1300,
    }

    def __init__(self, task_list, args, train_data, test_data, model):
        """Build loaders and trackers, and optionally resume from a checkpoint.

        Args:
            task_list: accepted for interface compatibility; not used here.
            args: run configuration. This constructor also stores the MD
                statistic dicts (``mean``, ``cov``, ``cov_inv``, ``mean_task``,
                ``cov_noise``, ``cov_inv_noise``) on it, and fills in
                ``load_path`` from the logger directory when resuming/testing
                without an explicit path.
            train_data: object exposing ``make_dataset(task_id)`` and ``task_list``.
            test_data: object exposing ``make_dataset(task_id)``.
            model: the model being trained/evaluated.
        """
        # noise cannot be used without use_md
        if args.noise:
            assert args.use_md

        self.args = args
        self.train_data = train_data
        self.test_data = test_data
        self.model = model

        self.preprocessed_task = -1
        self.received_data = -1

        self.train_loaders = []
        self.test_loaders = []
        self.calibration_loaders = []

        self.zeroshot = Zeroshot(self.args.model_clip, args)

        self.trackers = {
            'cil_tracker': Tracker(args),
            'til_tracker': Tracker(args),
            'cal_cil_tracker': Tracker(args),
            'auc_softmax_tracker': AUCTracker(args),
            'auc_md_tracker': AUCTracker(args),
            'openworld_softmax_tracker': OWTracker(args),
        }

        self.args.mean, self.args.cov, self.args.cov_inv = {}, {}, {}
        self.args.mean_task, self.args.cov_noise, self.args.cov_inv_noise = {}, {}, {}

        # Resume position; overwritten by load_train_step when resuming.
        self.init_task = 0
        self.init_epoch = 0

        # Prepare all the data loaders for convenience
        for t in range(self.args.n_tasks):
            self.receive_data(t)

        needs_checkpoint_dir = (
            self.args.resume_id is not None
            or self.args.test_task_id is not None
            or self.args.train_clf_pree_id is not None
        )
        if needs_checkpoint_dir and self.args.load_path is None:
            self.args.load_path = self.args.logger.dir()

        if self.args.resume_id is not None:
            self.load_task(self.args.resume_id)

    def single_head(self, task_id):
        """Train the model's single head for ``n_epochs`` epochs and save it.

        The number of steps per epoch scales with the number of classes seen
        so far (``num_cls_per_task * (task_id + 1)``). The head state dict is
        saved under the logger directory as
        ``{single_head_model_name}[_ood]_{task_id}``.

        Raises:
            NotImplementedError: if ``args.dataset`` has no known per-class
                sample count (previously this surfaced as an UnboundLocalError).
        """
        try:
            samples_per_class = self._TRAIN_SAMPLES_PER_CLASS[self.args.dataset]
        except KeyError:
            raise NotImplementedError(
                f"single_head: no per-class sample count for dataset {self.args.dataset!r}"
            ) from None

        self.model.preprocess_single_head(task_id)

        steps_per_epoch = int(samples_per_class * self.args.num_cls_per_task * (task_id + 1) / self.args.batch_size)

        for epoch in range(self.args.n_epochs):
            loss_total = 0
            for _ in range(steps_per_epoch):
                loss_total += self.model.train_single_head(task_id)
            if (epoch + 1) % self.args.eval_every == 0:
                metrics = self.model.acc()
                self.args.logger.print("Task {} | {}/{} | Loss {:.4f} | CIL {:.2f} | TIL {:.2f}".format(
                                                        task_id,
                                                        epoch + 1,
                                                        self.args.n_epochs,
                                                        loss_total,
                                                        metrics['cil_acc'],
                                                        metrics['til_acc'],
                                                        ))

        checkpoint = {
            'state_dict': self.model.net.head_state_dict()
        }

        if self.args.generate_ood:
            name = self.args.single_head_model_name + f"_ood_{task_id}"
        else:
            name = self.args.single_head_model_name + f"_{task_id}"

        torch.save(checkpoint, os.path.join(self.args.logger.dir(), name))

    def receive_all(self, task_id):
        """Ensure loaders exist for every task in ``0..task_id``."""
        for p_task_id in range(task_id + 1):
            self.receive_data(p_task_id)

    def receive_data(self, task_id):
        """Build and cache the loaders for ``task_id`` if they are not built yet.

        Loaders are appended to the lists, so this assumes tasks are received
        in increasing order (as ``__init__`` and ``receive_all`` do).
        """
        if len(self.train_loaders) > task_id:
            return

        self.args.logger.print(f"Received data of task {task_id}")
        if self.args.validation is None:
            t_train = self.train_data.make_dataset(task_id)
            t_test = self.test_data.make_dataset(task_id)
        else:
            # With validation, the train split provides both train and held-out data.
            t_train, t_test = self.train_data.make_dataset(task_id)

        if self.args.calibration:
            assert self.args.cal_batch_size > 0
            assert self.args.cal_epochs > 0
            assert self.args.cal_size > 0
            # Bug fix: previously referenced an undefined global `args`.
            t_train, t_cal = calibration_dataset(self.args, t_train)
            self.calibration_loaders.append(make_loader(t_cal, self.args, train='calibration'))

        self.train_loaders.append(make_loader(t_train, self.args, train='train'))
        self.test_loaders.append(make_loader(t_test, self.args, train='test'))

        self.current_train_loader = deepcopy(self.train_loaders[task_id])

    def load_train_step(self, task_id):
        """Restore network weights and the resume position from ``model_task_{task_id}``.

        Old-format checkpoints contain only the network state dict. New-format
        ones also carry ``task_id``/``epoch``, which set ``init_task``/``init_epoch``.
        A finished task (``epoch >= n_epochs``) resumes at the start of the next task.
        """
        checkpoint = custom_load(os.path.join(self.args.load_path, f'model_task_{task_id}'))

        # The old version only has model's state_dict. Check if it's old version or not
        if 'optimizer' not in checkpoint:
            custom_load_state_dict(self.args, self.model.net, checkpoint)
            return

        custom_load_state_dict(self.args, self.model.net, checkpoint['state_dict'])
        self.init_task = checkpoint['task_id']
        self.init_epoch = checkpoint['epoch']

        if self.init_epoch >= self.args.n_epochs:
            self.init_task += 1
            self.init_epoch = 0

    def load_model_step(self, task_id=None):
        """Copy each entry of ``saving_buffer_{task_id}`` onto the model if it has that attribute."""
        checkpoint = custom_load(os.path.join(self.args.load_path, f'saving_buffer_{task_id}'))
        for key, val in checkpoint.items():
            if hasattr(self.model, key):
                self.args.logger.print(f"** {self.model.__class__.__name__}: Update {key} values **")
                setattr(self.model, key, val)
            else:
                self.args.logger.print(f"** WARNING: {self.model.__class__.__name__}: {key} values are not updated **")

    def load_trackers(self, task_id=None):
        """Best-effort reload of each tracker's matrix; missing files are skipped.

        NOTE(review): reads from ``logger.dir()`` (where ``save_trackers``
        writes), not from ``args.load_path`` like the other loaders — confirm
        this is intended when resuming from a different directory.
        """
        for k, v in self.trackers.items():
            try:
                v.mat = torch.load(os.path.join(self.args.logger.dir(), k))
                self.args.logger.print(f"Loaded {k}")
            except FileNotFoundError:
                pass

    def load_task(self, task_id):
        """Restore all state needed to continue from (or evaluate) ``task_id``."""
        self.load_all_MD_stats(task_id)
        self.preprocess_all_tasks(task_id)
        self.load_train_step(task_id)
        self.load_model_step(task_id)
        self.load_trackers(task_id)

    def inputs_to_features(self, inputs):
        """Return ViT features for ``inputs`` if ``args.model_vit`` is set, else the inputs unchanged."""
        with torch.no_grad():
            if self.args.model_vit:
                return self.args.model_vit.forward_features(inputs)
            return inputs

    def test_task(self, test_loader, **kwargs):
        """Evaluate the model on ``test_loader`` and return ``model.acc()``.

        test_loader: test loader. This doesn't necessarily have to be same as task_id.
        kwargs: forwarded unchanged to ``model.evaluate``.
        """
        self.model.reset_eval()

        for _, x, y, _, _, _ in test_loader:
            x, y = x.to(self.args.device), y.to(self.args.device)

            x = self.inputs_to_features(x)

            if self.args.zero_shot:
                raise NotImplementedError()

            self.model.evaluate(x, y, **kwargs)

        return self.model.acc()

    def _save_cil_md_outputs(self, metrics, data_task_id):
        """Save CIL/MD outputs and labels from ``metrics`` under ``<log dir>/cil_md_output_all``."""
        save_dir = os.path.join(self.args.logger.dir(), 'cil_md_output_all')
        os.makedirs(save_dir, exist_ok=True)
        suffix = f'learned_{self.preprocessed_task}_data_{data_task_id}'
        torch.save(metrics['cil_output_save'], os.path.join(save_dir, f'cil_output_save_{suffix}'))
        torch.save(metrics['md_output_save'], os.path.join(save_dir, f'md_output_save_{suffix}'))
        torch.save(metrics['label_save'], os.path.join(save_dir, f'label_save_{suffix}'))

    def test_auc(self, task_id, metrics=None, **kwargs):
        """Compute per-task and open-world AUCs with ``task_id`` as in-distribution.

        Args:
            task_id: the task id the model will use (in-distribution).
            metrics: precomputed metrics on ``test_loaders[task_id]``; computed if None.
            kwargs: ``epoch`` (optional, used only for logging; defaults to the last epoch).
        """
        epoch = kwargs.get('epoch', self.args.n_epochs - 1)

        if metrics is None:
            inputs_evaluate = {
                'task_id': task_id,
                'report_cil': True,
                'total_learned_task_id': self.preprocessed_task,
                'ensemble': False,
                'true_id': task_id,
            }
            metrics = self.test_task(self.test_loaders[task_id], **inputs_evaluate)

        self.trackers['cil_tracker'].update(metrics['cil_acc'], task_id, task_id)
        self.trackers['til_tracker'].update(metrics['til_acc'], task_id, task_id)

        in_scores = metrics['scores']
        if self.args.compute_md:
            in_scores_md = metrics['scores_md']

        if self.args.save_cil_md:
            self._save_cil_md_outputs(metrics, task_id)

        auc_list, auc_list_md = [], []
        # Open-world AUC: scores of all learned tasks vs. each not-yet-learned task.
        auc_total_in_list, auc_total_out_list, out_id_list = [metrics['scores_total']], [], []

        for task_out in range(self.args.n_tasks):
            if task_out == task_id:
                continue

            if self.args.validation is None:
                t_test = self.test_data.make_dataset(task_out)
            else:
                _, t_test = self.train_data.make_dataset(task_out)

            ood_loader = make_loader(t_test, self.args, train='test')

            inputs_evaluate = {
                'task_id': task_id,
                'report_cil': True,
                'total_learned_task_id': task_id,
                'ensemble': False,
                'true_id': task_out if task_out < task_id else None
            }

            out_metrics = self.test_task(ood_loader, **inputs_evaluate)

            if task_out < task_id:
                self.trackers['cil_tracker'].update(out_metrics['cil_acc'], task_id, task_out)
                self.trackers['til_tracker'].update(out_metrics['til_acc'], task_id, task_out)

                if self.args.save_cil_md:
                    self._save_cil_md_outputs(out_metrics, task_out)

            auc = compute_auc(in_scores, out_metrics['scores'])
            auc_list.append(auc * 100)
            self.args.logger.print("Epoch {}/{} | in/out: {}/{} | Softmax AUC: {:.2f}".format(epoch + 1, self.args.n_epochs, task_id, task_out, auc_list[-1]), end=' ')
            self.trackers['auc_softmax_tracker'].update(auc_list[-1], task_id, task_out)

            if self.args.compute_md:
                auc_md = compute_auc(in_scores_md, out_metrics['scores_md'])
                auc_list_md.append(auc_md * 100)
                self.args.logger.print("| MD AUC: {:.2f}".format(auc_list_md[-1]))
                self.trackers['auc_md_tracker'].update(auc_list_md[-1], task_id, task_out)
            else:
                self.args.logger.print('')

            if task_out < task_id:
                auc_total_in_list.append(out_metrics['scores_total'])
            else:
                auc_total_out_list.append(out_metrics['scores_total'])
                out_id_list.append(task_out)

        # With a single task there are no out-of-distribution tasks to average over.
        if auc_list:
            self.args.logger.print("Epoch {}/{} | Average Softmax AUC: {:.2f}".format(epoch + 1, self.args.n_epochs, np.array(auc_list).mean()), end=' ')
            if self.args.compute_md:
                self.args.logger.print("| Average MD AUC: {:.2f}".format(np.array(auc_list_md).mean()))
            else:
                self.args.logger.print('')

        for task_out, out_scores in zip(out_id_list, auc_total_out_list):
            auc = compute_auc(auc_total_in_list, out_scores)
            self.args.logger.print("Epoch {}/{} | total in/out: {}/{} | AUC: {:.2f}".format(epoch + 1, self.args.n_epochs, task_id, task_out, auc * 100))
            self.trackers['openworld_softmax_tracker'].update(auc * 100, task_id, task_out)
        if auc_total_in_list and auc_total_out_list:
            auc = compute_auc(auc_total_in_list, auc_total_out_list)
            self.args.logger.print("Epoch {}/{} | total in | AUC: {:.2f}".format(epoch + 1, self.args.n_epochs, auc * 100))

    def save_trackers(self, **kwargs):
        """Save each tracker's matrix to ``<log dir>/<tracker name>``."""
        for k, v in self.trackers.items():
            torch.save(v.mat, os.path.join(self.args.logger.dir(), k))

    def save_train_step(self, task_id, epoch, **kwargs):
        """Save the trainable part of the network, optimizer state, and resume position.

        The stored ``epoch`` is ``epoch + 1`` (number of completed epochs).
        """
        # Pick the smallest state dict the network exposes. This used to sit
        # under `hasattr(self.model, 'save')`, leaving `state_dict` unbound
        # (UnboundLocalError) for models without a `save` method.
        net = self.model.net
        if hasattr(net, 'adapter_state_dict'):
            state_dict = net.adapter_state_dict()
        elif hasattr(net, 'prompt_head_state_dict'):
            state_dict = net.prompt_head_state_dict()
        else:
            self.args.logger.print("******** Warning: Saving the entire network ********")
            state_dict = net.state_dict()

        checkpoint = {
            'state_dict': state_dict,
            'optimizer': self.model.optimizer.state_dict(),
            'task_id': task_id,
            'epoch': epoch + 1,
        }

        torch.save(checkpoint, os.path.join(self.args.logger.dir(), f"model_task_{task_id}"))

    def save_model_step(self, task_id, **kwargs):
        """Delegate model-specific saving to ``model.save`` when the model provides it."""
        if not hasattr(self.model, 'save'):
            self.args.logger.print(f"** WARNING: {self.model.__class__.__name__} has no save(); model-specific state not saved **")
            return
        margin = self.current_train_loader.dataset.margin if self.args.tr_dynamics else None
        self.model.save(task_id, margin=margin)

    def save_all(self, **kwargs):
        """Save training step, model-specific state, optional statistics, and trackers.

        kwargs must contain ``task_id`` and ``epoch``.
        """
        self.save_train_step(**kwargs)
        self.save_model_step(**kwargs)

        # Save statistics e.g. mean, cov, cov_inv
        if self.args.save_statistics:
            np.save(os.path.join(self.args.logger.dir(), 'statistics'),
                    self.model.statistics)

        self.save_trackers(**kwargs)

    def load_task_MD_stats(self, task_id):
        """Load covariance(s) and class means for ``task_id`` into ``args`` for MD scoring.

        If ``recompute_md`` is set, ``model.compute_stats`` is called first. A
        missing statistics file is logged and skipped.
        """
        if self.args.recompute_md:
            self.model.compute_stats(task_id, self.train_loaders[task_id])

        try:
            cov = custom_load(os.path.join(self.args.load_path,
                                           f'{self.args.cov_task_name}_{task_id}.npy'),
                              file_type='numpy')
            self.args.cov[task_id] = cov
            # Shrink toward the identity before inverting.
            self.args.cov_inv[task_id] = np.linalg.inv(0.8 * cov + 0.2 * np.eye(len(cov)))
            if self.args.noise:
                self.args.logger.print("Importing Noise Stats")
                mean = custom_load(os.path.join(self.args.load_path,
                                                f'{self.args.mean_task_name}_{task_id}.npy'),
                                   file_type='numpy')
                self.args.mean_task[task_id] = mean
                cov = custom_load(os.path.join(self.args.load_path,
                                               f'{self.args.cov_task_noise_name}_{task_id}.npy'),
                                  file_type='numpy')
                self.args.cov_noise[task_id] = cov
                self.args.cov_inv_noise[task_id] = np.linalg.inv(cov)
            # Class labels of this task are the contiguous range below.
            first_cls = task_id * self.args.num_cls_per_task
            for y in range(first_cls, first_cls + self.args.num_cls_per_task):
                mean = custom_load(os.path.join(self.args.load_path,
                                                f'{self.args.mean_label_name}_{y}.npy'),
                                   file_type='numpy')
                self.args.mean[y] = mean
            self.args.logger.print("Means for classes:", self.args.mean.keys())
            self.args.logger.print("Covs for classes:", self.args.cov.keys())
        except FileNotFoundError:
            self.args.logger.print(f"*** No MD for Task {task_id}***")

    def load_all_MD_stats(self, task_id):
        """Load MD statistics for every task in ``0..task_id``."""
        for p_task_id in range(task_id + 1):
            self.load_task_MD_stats(p_task_id)

    def preprocess_all_tasks(self, task_id):
        """Run ``preprocess_task`` for every task in ``0..task_id`` in order."""
        for p_task_id in range(task_id + 1):
            self.preprocess_task(p_task_id)

    def preprocess_task(self, task_id):
        """Point ``current_train_loader`` at ``task_id`` and run the model's preprocessing once."""
        self.current_train_loader = self.train_loaders[task_id]

        if self.preprocessed_task == task_id:
            return

        self.args.logger.print(f"Preprocessing task {task_id}")
        inputs_preprocess = {
            'names': self.train_data.task_list[task_id][0],
            'labels': self.train_data.task_list[task_id][1],
            'task_id': task_id,
            'loader': self.current_train_loader
        }

        if hasattr(self.model, 'preprocess_task'):
            self.model.preprocess_task(**inputs_preprocess)
            self.args.logger.print(Counter(self.current_train_loader.dataset.targets))

        self.preprocessed_task = task_id

    def train_task(self, task_id):
        """Train a task. Contains preprocessing of model, post-processing, and testing.

        Runs epochs ``init_epoch .. n_epochs - 1``. Every ``eval_every`` epochs
        (and at the final epoch) it optionally recomputes MD statistics,
        evaluates on the task's test loader, optionally computes AUCs, and
        saves everything. At the final epoch ``model.end_task`` is called when
        available. Afterwards all tasks seen so far are tested and trackers saved.
        """
        self.preprocess_task(task_id)

        for epoch in range(self.init_epoch, self.args.n_epochs):
            inputs_observe = {
                'task_id': task_id,
                'B': len(self.current_train_loader),
            }

            task_loss_list = []
            self.model.reset_eval()
            # orig is the original data (mostly likely numpy for CIFAR, and indices for ImageNet)
            for b, (idx, x, y, f_y, names, orig) in tqdm(enumerate(self.current_train_loader)):
                inputs_observe['b'] = b
                inputs_observe['indices'] = idx

                # for simplicity, consider that we know the labels ahead
                f_y = f_y[:, 1]
                x, y = x.to(self.args.device), y.to(self.args.device)

                x = self.inputs_to_features(x)

                if self.args.zero_shot:
                    raise NotImplementedError()

                loss = self.model.observe(x, y, names, x, f_y, **inputs_observe)

                if 'prompt_learning' in self.args.model and 'slda' in self.args.model and epoch == 0:
                    self.model.slda_observe(x, y, names, x, f_y, **inputs_observe)

                task_loss_list.append(loss)

            self.args.logger.print("Epoch {}/{} | Loss: {:.4f} | Acc: {:.2f}".format(epoch + 1,
                                                            self.args.n_epochs,
                                                            self.model.total_loss / len(self.current_train_loader),
                                                            self.model.correct / self.model.total * 100))

            if epoch == 0 and self.args.zero_shot:
                self.args.logger.print("Train Data | Task {}, Zero-shot Acc: {:.2f} | ".format(task_id, self.zeroshot.acc()['cil_acc']), end='')

            metrics = self.model.acc()

            if self.args.tr_dynamics:
                if self.args.dataset == 'timgnet':
                    raise NotImplementedError('margin not implemented for imagenet and timgnet')
                assert self.args.model == 'prompt_learning'  # Only prompt_learning method has implemented AUM analysis
                margin = metrics['margin']
                indices = metrics['indices']

                # Running mean over epochs: mean_e = (mean_{e-1} * e + margin_e) / (e + 1).
                # The previous code multiplied by (epoch + 1), which yields
                # mean_{e-1} + margin_e / (e + 1) instead of a mean.
                dataset = self.current_train_loader.dataset
                dataset.margin[indices] = (dataset.margin[indices] * epoch + margin) / (epoch + 1)

            # Save features for MD statistics, use TRAIN data
            if (epoch + 1) % self.args.eval_every == 0 or (epoch + 1) == self.args.n_epochs:
                # If compute_md is true, obtain the features and compute/save the statistics for MD
                if self.args.compute_md:
                    self.model.compute_stats(task_id, self.train_loaders[task_id])

                if (epoch + 1) == self.args.n_epochs:
                    self.args.logger.print("End task...")

                    inputs_end_task = {
                        'cal_loaders': self.calibration_loaders[:task_id + 1],
                        'test_loaders': self.test_loaders[:task_id + 1],
                        'train_loader': self.train_loaders[task_id],
                        'task_id': task_id
                    }

                    if hasattr(self.model, 'end_task'):
                        self.model.end_task(**inputs_end_task)

                inputs_evaluate = {
                    'task_id': task_id,
                    'report_cil': True,
                    'total_learned_task_id': task_id,
                    'ensemble': False
                }

                metrics = self.test_task(self.test_loaders[task_id], **inputs_evaluate)

                self.args.logger.print("Task {}, Epoch {}/{}, Total Loss: {:.4f}, CIL Acc: {:.2f}, TIL Acc: {:.2f}".format(task_id,
                                    epoch + 1, self.args.n_epochs, np.mean(task_loss_list),
                                    metrics['cil_acc'], metrics['til_acc']))

                # if compute_AUC is true, compute its AUC at eval_every
                if self.args.compute_auc:
                    self.test_auc(task_id, metrics, epoch=epoch)

                self.save_all(task_id=task_id, epoch=epoch)

        # A resumed epoch offset applies only to the task being resumed;
        # later tasks in train_all must start from epoch 0.
        self.init_epoch = 0

        self.test_all(task_id, self.test_loaders[:task_id + 1])
        self.save_trackers()

    def train_all(self):
        """Train tasks ``init_task .. n_tasks - 1`` in order."""
        for task_id in range(self.init_task, self.args.n_tasks):
            self.train_task(task_id)

    def test_all_all(self, task_id):
        """Test all from task 0 to task_id. Each task's model must be saved.

        For every learned task ``p`` in ``0..task_id``, reload the state saved
        after task ``p`` and test on tasks ``0..p``.
        """
        for p_task_id in range(task_id + 1):
            self.load_task(p_task_id)
            self.test_all(p_task_id, self.test_loaders[:p_task_id + 1])

    def test_all(self, task_id, test_loaders):
        """At current task, test all the tasks until current task.

        This is different from test_all_all. test_all_all tests
        for all test_task <= current_task for all current_task <= self.args.n_tasks.
        Optionally saves outputs, plots t-SNE / confusion matrices, updates the
        CIL/TIL trackers and prints the tracker summaries.
        """
        self.args.logger.print("######################")
        true_lab, pred_lab = [], []
        last_p_task_id = len(test_loaders) - 1
        for p_task_id, loader in enumerate(test_loaders):

            inputs_evaluate = {
                'task_id': p_task_id,
                'report_cil': True,
                'total_learned_task_id': last_p_task_id,
                'ensemble': False
            }

            metrics = self.test_task(loader, **inputs_evaluate)

            if self.args.save_output:
                save_output_path = os.path.join(self.args.logger.dir(), 'output_label')
                os.makedirs(save_output_path, exist_ok=True)
                np.save(os.path.join(save_output_path,
                                     f"output_learned_{task_id}_task_{p_task_id}"),
                        metrics["output_list"])
                np.save(os.path.join(save_output_path,
                                     f"label_learned_{task_id}_task_{p_task_id}"),
                        metrics["label_list"])

            if self.args.tsne:
                tsne(np.concatenate(self.model.output_list),
                     np.concatenate(self.model.label_list),
                     logger=self.args.logger)
            if self.args.confusion:
                true_lab_ = np.concatenate(self.model.true_lab)
                pred_lab_ = np.concatenate(self.model.pred_lab)

                plot_confusion(true_lab_, pred_lab_, self.model.seen_names, task_id, p_task_id,
                               logger=self.args.logger, num_cls_per_task=self.args.num_cls_per_task)

                true_lab.append(true_lab_)
                pred_lab.append(pred_lab_)

            # After the last task, plot the confusion matrix over all tested tasks.
            if self.args.confusion and p_task_id == last_p_task_id:
                true_lab_ = np.concatenate(true_lab)
                pred_lab_ = np.concatenate(pred_lab)
                plot_confusion(true_lab_, pred_lab_, self.model.seen_names,
                               name='confusion mat task {}'.format(p_task_id),
                               logger=self.args.logger, num_cls_per_task=self.args.num_cls_per_task)

            self.trackers['cil_tracker'].update(metrics['cil_acc'], task_id, p_task_id)
            self.trackers['til_tracker'].update(metrics['til_acc'], task_id, p_task_id)

        self.args.logger.print()
        if self.args.compute_auc:
            self.args.logger.print("Softmax AUC result")
            self.trackers['auc_softmax_tracker'].print_result(task_id, type='acc')
            self.args.logger.print("Open World result")
            self.trackers['openworld_softmax_tracker'].print_result(task_id, type='acc')
        if self.args.compute_md:
            self.args.logger.print("MD AUC result")
            self.trackers['auc_md_tracker'].print_result(task_id, type='acc')
        self.args.logger.print("CIL result")
        self.trackers['cil_tracker'].print_result(task_id, type='acc')
        self.trackers['cil_tracker'].print_result(task_id, type='forget')
        self.args.logger.print("TIL result")
        self.trackers['til_tracker'].print_result(task_id, type='acc')
        self.trackers['til_tracker'].print_result(task_id, type='forget')
        self.args.logger.print()
        self.args.logger.now()
