import logging
import numpy as np
import torch
import copy
from torch import nn
from tqdm import tqdm
from torch import optim
from torch.optim import Optimizer
import math
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils_online.inc_net_online import MVP
from models_online.base import BaseLearner
from buffer.buffer import ProtoBuffer, Reservoir, QueryReservoir
from utils.toolkit import tensor2numpy
from utils_online.si_blurry import IndexedDataset, OnlineSampler, OnlineTestSampler
import wandb
import sys

# tune the model at first session with vpt, and then conduct simple shot.
# Number of worker processes passed as `num_workers` to every DataLoader built in this module.
num_workers = 8

class Learner(BaseLearner):
    def __init__(self, args):
        """Set up the MVP network, training hyper-parameters and prototype buffer.

        Args:
            args: Configuration mapping. Keys read here are ``nb_classes``,
                ``batch_size``, ``init_lr``, ``weight_decay``, ``min_lr`` and
                ``use_mask``.
        """
        super().__init__(args)

        # NOTE(review): `self.args` is read here before the assignment further
        # down, so it is presumably already set by the base constructor.
        self._network = MVP(num_classes=self.args['nb_classes'])

        self.batch_size = args["batch_size"]
        self.init_lr = args["init_lr"]

        # Optional settings: a None in the config selects the built-in default.
        weight_decay = args["weight_decay"]
        self.weight_decay = 0.0005 if weight_decay is None else weight_decay
        min_lr = args["min_lr"]
        self.min_lr = 1e-8 if min_lr is None else min_lr

        self.use_mask = args["use_mask"]
        # Additive logit mask: every slot starts at -inf and is set to 0 by
        # `add_new_class` once a class has been observed.
        self.mask = torch.full((args["nb_classes"],), -torch.inf, device=self._device)
        # Raw labels in order of first appearance; a label's position in this
        # list is the logit index used for it.
        self.exposed_classes = []
        self.args = args

        # Created lazily by `_train` on the first task.
        self.optim = self.scheduler = None

        # Switches and coefficients read by `loss_fn`.
        self.use_afs = self.use_gsf = True
        self.alpha = 0.5
        self.gamma = 2.

        # FGH stuff
        self.beta1, self.beta2 = 0.9, 0.999
        self.step = 0

        self.grad_weight, self.old_grad = {}, {}
        self.m, self.v = {}, {}

        self.buffer = ProtoBuffer(
            max_size=args['nb_classes'],
            shape=(768,),
            device=self._device,
        )

        total_params = 0
        for param in self._network.parameters():
            total_params += param.numel()
        logging.info(f'{total_params:,} total parameters.')


    def after_task(self):
        """Record that every class counted so far is now known."""
        # `_known_classes` is the lower bound `incremental_train` uses when it
        # selects the classes of the next (non-blurry) task.
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager, **kwargs):
        """Advance to the next task, build its data loaders and train on it.

        In the blurry setting (``args['blurry']``) the whole label range is
        loaded and the train loader draws through an ``OnlineSampler`` that is
        told the current task index. Otherwise only the classes in
        ``[_known_classes, _total_classes)`` are loaded for training. The test
        split always covers ``[0, _total_classes)``.

        Args:
            data_manager: Provides ``get_dataset`` and ``get_task_size``; it is
                also stored on ``self.data_manager``.
            **kwargs: Only ``n_tasks`` (default 10) is read; it is passed to
                ``OnlineSampler`` in the blurry setting.
        """
        n_tasks = kwargs.get('n_tasks', 10)
        self._cur_task += 1
        blurry = self.args['blurry']

        if blurry:
            self._total_classes = self.args['nb_classes']
            first_train_class = 0
        else:
            self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
            logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
            first_train_class = self._known_classes

        train_dataset = data_manager.get_dataset(
            np.arange(first_train_class, self._total_classes), source="train", mode="train"
        )
        # The test split is requested once and shared by the attribute and the
        # loader instead of being rebuilt for each of them.
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_dataset = test_dataset

        if blurry:
            self.train_dataset = IndexedDataset(train_dataset)
            self.train_sampler = OnlineSampler(self.train_dataset, n_tasks, 10, 50, self.args['seed'], False, 1)
            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                sampler=self.train_sampler,
                num_workers=num_workers,
                pin_memory=True,
            )
            self.test_loader = DataLoader(
                self.test_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True,
            )
            self.train_sampler.set_task(self._cur_task)
        else:
            self.train_dataset = train_dataset
            self.train_loader = DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True,
                num_workers=num_workers,
            )
            self.test_loader = DataLoader(
                test_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=num_workers,
            )

        self.data_manager = data_manager

        # Wrap for multi-GPU training only for the duration of `_train`, then
        # unwrap so the attribute holds the bare network again.
        multi_gpu = len(self._multiple_gpus) > 1
        if multi_gpu:
            print('Multiple GPUs')
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        if multi_gpu:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Move the network to the training device and run the training loop.

        The optimizer and scheduler are created only while ``_cur_task`` is 0
        and are reused for every later task.
        """
        self._network.to(self._device)

        is_first_task = self._cur_task == 0
        if is_first_task:
            # The scheduler is built second because it wraps `self.optim`.
            self.optim = self.get_optimizer()
            self.scheduler = self.get_scheduler()

        self._init_train(train_loader, test_loader)

    def get_optimizer(self):
        params = self._network.parameters()
                        
        if self.args['optimizer'] == 'sgd':
            optimizer = optim.SGD(params, momentum=0.9, lr=self.init_lr,weight_decay=self.weight_decay)
        elif self.args['optimizer'] == 'adam':
            optimizer = optim.Adam(params, lr=self.init_lr, weight_decay=self.weight_decay)
        elif self.args['optimizer'] == 'adamw':
            optimizer = optim.AdamW(params, lr=self.init_lr, weight_decay=self.weight_decay)
        
        return optimizer

    def get_scheduler(self):
        if self.args["scheduler"] == 'cosine':
            scheduler = CosineSchedule(self.optim, K=self.args["tuned_epoch"])
        elif self.args["scheduler"] == 'steplr':
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer=self.optim, milestones=self.args["init_milestones"], gamma=self.args["init_lr_decay"])
        elif self.args["scheduler"] == 'constant':
            scheduler = None

        return scheduler

    def random_lr_change(self):
        new_lr = torch.rand(1).item() * 0.01 + 0.00001  # Randomly select a new learning rate between 0.01 and 0.11
        for param_group in self.optim.param_groups:
            param_group['lr'] = new_lr
                
    def _get_ignore(self, sample_grad, batch_grad):
        ign_score = (1. - torch.cosine_similarity(sample_grad, batch_grad, dim=1))#B
        return ign_score

    def _get_compensation(self, y, feat):
        head_w = self._network.backbone.head.weight[y].clone().detach()
        cps_score = (1. - torch.cosine_similarity(head_w, feat, dim=1) + 0.5)#B
        return cps_score
    
    def _get_score(self, feat, y, mask):
        sample_grad, batch_grad = self._compute_grads(feat, y, mask)
        ign_score = self._get_ignore(sample_grad, batch_grad)
        cps_score = self._get_compensation(y, feat)
        return ign_score, cps_score
    
    def _compute_grads(self, feature, y, mask):
        head = copy.deepcopy(self._network.backbone.head)
        head.zero_grad()
        logit = head(feature.detach())
        if self.use_mask:
            logit = logit * mask.clone().detach()
        logit = logit + self.mask
        
        sample_loss = F.cross_entropy(logit, y, reduction='none')
        sample_grad = []
        for idx in range(len(y)):
            sample_loss[idx].backward(retain_graph=True)
            _g = head.weight.grad[y[idx]].clone()
            sample_grad.append(_g)
            head.zero_grad()
        sample_grad = torch.stack(sample_grad)    #B,dim
        
        head.zero_grad()
        batch_loss = F.cross_entropy(logit, y, reduction='mean')
        batch_loss.backward(retain_graph=True)
        total_batch_grad = head.weight.grad[:len(self.exposed_classes)].clone()  # C,dim
        idx = torch.arange(len(y))
        batch_grad = total_batch_grad[y[idx]]    #B,dim
        
        return sample_grad, batch_grad
    
    def loss_fn(self, feature, mask, y):
        ign_score, cps_score = self._get_score(feature.detach(), y, mask)

        if self.use_afs:
            logit = self._network.forward_head(feature)
            logit = self._network.forward_head(feature / (cps_score.unsqueeze(1)))
        else:
            logit = self._network.forward_head(feature)
        if self.use_mask:
            logit = logit * mask
        logit = logit + self.mask
        log_p = F.log_softmax(logit, dim=1)
        loss = F.nll_loss(log_p, y)
        if self.use_gsf:
            loss = (1-self.alpha)* loss + self.alpha * (ign_score ** self.gamma) * loss
        return loss.mean() + self._network.get_similarity_loss()
    
    def model_forward(self, x, y):
        feature, mask = self._network.forward_features(x)
        logit = self._network.forward_head(feature)
        if self.use_mask:
            logit = logit * mask
        logit = logit + self.mask
        loss = self.loss_fn(feature, mask, y)
        return logit, loss, feature

    def _init_train(self, train_loader, test_loader):
        """Train online for ``args['tuned_epoch']`` epochs and log progress.

        Per batch: register unseen labels, remap labels to their index in
        ``exposed_classes``, take one optimizer step on the loss from
        ``model_forward``, log learning rate and loss to wandb, and pass the
        batch features together with the raw labels to the prototype buffer.
        Every fifth epoch the test accuracy is computed as well.

        Args:
            train_loader: Iterable of ``(index, inputs, targets)`` batches.
            test_loader: Loader used for the periodic accuracy check. In the
                blurry setting it is rebuilt each epoch around an
                ``OnlineTestSampler`` given the unique labels retrieved from
                the buffer.
        """
        info = None
        prog_bar = tqdm(range(self.args['tuned_epoch']))
        for epoch in prog_bar:
            self._network.train()

            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                self.add_new_class(targets)

                x = inputs.to(self._device)
                # Remap raw labels to their position in `exposed_classes`. A
                # fresh tensor is built so `targets` keeps the raw labels that
                # the buffer update below receives, whatever the device is
                # (an in-place remap would alias `targets` on CPU).
                class_to_index = {c: k for k, c in enumerate(self.exposed_classes)}
                y = torch.tensor(
                    [class_to_index[t] for t in targets.tolist()],
                    dtype=targets.dtype,
                    device=self._device,
                )

                logit, loss, f = self.model_forward(x, y)

                # LR monitoring
                lr = self.optim.param_groups[0]['lr']
                wandb.log({'lr': lr, 'step': self.step})

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                loss_value = loss.item()
                losses += loss_value

                if math.isnan(loss_value):
                    print("LOSS IS NAN")
                    sys.exit(0)

                wandb.log({
                    'loss': loss_value,
                    'step': self.step
                    })
                # Advance the logged step; it previously stayed at 0 forever.
                self.step += 1

                # buffer update
                self.buffer.update(queries=None, keys=None, values=None, labels=targets.detach(), features=f.detach())

                preds = torch.max(logit, dim=1)[1]
                correct += preds.eq(y).sum().item()
                total += len(targets)

            if self.scheduler is not None:
                self.scheduler.step()

            # An empty loader (possible with drop_last=True) yields zeros
            # instead of a division error.
            train_acc = np.around(correct * 100 / total, decimals=2) if total else 0.0
            mean_loss = losses / max(len(train_loader), 1)

            if self.args['blurry']:
                _, _, _, mem_y, _, _ = self.buffer.random_retrieve(n_imgs=self.args['nb_classes'])
                test_sampler = OnlineTestSampler(self.test_dataset, mem_y.unique())
                test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, sampler=test_sampler, num_workers=num_workers)
                self.test_loader = test_loader

            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                self.args['tuned_epoch'],
                mean_loss,
                train_acc,
            )
            if (epoch + 1) % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info += ", Test_accy {:.2f}".format(test_acc)
            prog_bar.set_description(info)

        # Nothing to report when no epoch ran.
        if info is not None:
            logging.info(info)
    
    def add_new_class(self, class_name):
        for label in class_name:
            if label.item() not in self.exposed_classes:
                self.exposed_classes.append(label.item())
        # if self.distributed:
        #     exposed_classes = torch.cat(self.all_gather(torch.tensor(self.exposed_classes, device=self.device))).cpu().tolist()
        #     self.exposed_classes = []
        #     for cls in exposed_classes:
        #         if cls not in self.exposed_classes:
        #             self.exposed_classes.append(cls)
        # self.memory.add_new_class(cls_list=self.exposed_classes)
        self.mask[:len(self.exposed_classes)] = 0
        # if 'reset' in self.sched_name:
            # self.update_schedule(reset=True)

    def _eval_cnn(self, loader):
        self._network.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            for j in range(len(targets)):
                targets[j] = self.exposed_classes.index(targets[j].item())
            with torch.no_grad():
                logit = self._network(inputs)
                logit = logit + self.mask
                
            predicts = torch.topk(
                logit, k=self.topk, dim=1, largest=True, sorted=True
            )[
                1
            ]  # [bs, topk]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_true)  # [N, topk]

    def _compute_accuracy(self, model, loader):
        model.eval()
        correct, total = 0, 0
        for i, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            for j in range(len(targets)):
                targets[j] = self.exposed_classes.index(targets[j].item())
            with torch.no_grad():
                outputs = model(inputs)[:, :self._total_classes]
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)

        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)


class _LRScheduler(object):
    """Minimal epoch-indexed learning-rate scheduler base class.

    Subclasses implement ``get_lr`` to return one learning rate per param
    group; ``step`` writes those values into the optimizer. The constructor
    records each group's ``initial_lr``, applies ``step(last_epoch + 1)`` once
    and then resets ``last_epoch`` to the value it was given.
    """

    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        groups = optimizer.param_groups
        if last_epoch == -1:
            # Fresh start: remember the current lr as the base lr.
            for group in groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # Resuming: every group must already carry its base lr.
            for position, group in enumerate(groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(position))
        self.base_lrs = [group['initial_lr'] for group in groups]
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def state_dict(self):
        """Return every instance attribute except the optimizer as a dict."""
        state = dict(self.__dict__)
        state.pop('optimizer', None)
        return state

    def load_state_dict(self, state_dict):
        """Restore attributes from a dict produced by :meth:`state_dict`.

        Arguments:
            state_dict (dict): scheduler state to copy into this instance.
        """
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return the learning rates for the current epoch (subclass hook)."""
        raise NotImplementedError

    def step(self, epoch=None):
        """Move to ``epoch`` (default: next epoch) and apply ``get_lr``."""
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        new_lrs = self.get_lr()
        for group, lr in zip(self.optimizer.param_groups, new_lrs):
            group['lr'] = lr

class CosineSchedule(_LRScheduler):
    """Cosine-shaped learning-rate decay parameterised by an epoch count ``K``.

    Each base lr is scaled by ``cos(99 * pi * e / (200 * (K - 1)))`` where
    ``e`` is ``last_epoch`` clamped to at least 1 and ``K - 1`` is clamped to
    at least 1.
    """

    def __init__(self, optimizer, K):
        # K must exist before the base constructor runs, because that
        # constructor already calls step() and therefore get_lr().
        self.K = K
        super().__init__(optimizer, -1)

    def cosine(self, base_lr):
        """Return ``base_lr`` scaled by the cosine factor of the current epoch."""
        epoch = max(self.last_epoch, 1)
        horizon = max(self.K - 1, 1)
        return base_lr * math.cos((99 * math.pi * epoch) / (200 * horizon))

    def get_lr(self):
        """Return the scaled learning rate for every base lr."""
        return list(map(self.cosine, self.base_lrs))