import logging
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from torch import optim
from torch.optim import Optimizer
import math
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils_online.inc_net_online import CodaPromptVitNetOnline
from models_online.base import BaseLearner
from buffer.buffer import ProtoBuffer, Reservoir, QueryReservoir
from utils.toolkit import tensor2numpy
from utils_online.si_blurry import IndexedDataset, OnlineSampler, OnlineTestSampler
import wandb
import sys

# Online CODA-Prompt learner: tunes only the prompt module and the fc head, with adaptive gradient reweighting and buffer-based fc recalibration.
# Worker process count shared by every DataLoader built in this module.
num_workers: int = 8

class Learner(BaseLearner):
    """Online continual learner built on ``CodaPromptVitNetOnline``.

    Only the prompt module and the fc head are optimized (see ``get_optimizer``).
    Each training step minimizes

    * cross-entropy on the current batch, with logits of classes absent from
      the batch masked to ``-inf``;
    * ``recab_coef`` times an fc "recalibration" cross-entropy computed on
      features sampled from a ``ProtoBuffer``;
    * the (summed) prompt loss returned by the network;

    and, before the optimizer step, rescales each parameter gradient by an
    element-wise weight maintained in ``_reweight_gradients``.
    """

    def __init__(self, args):
        """Build the network, optimizer placeholders and feature buffer.

        Reads ``batch_size``, ``init_lr``, ``weight_decay`` (0.0005 when None),
        ``min_lr`` (1e-8 when None), ``recab_coef`` and ``nb_classes`` from
        ``args``. ``min_lr`` is stored but not used by this class.
        """
        super().__init__(args)

        self._network = CodaPromptVitNetOnline(args, True)

        self.batch_size = args["batch_size"]
        self.init_lr = args["init_lr"]
        self.weight_decay = args["weight_decay"] if args["weight_decay"] is not None else 0.0005
        self.min_lr = args["min_lr"] if args["min_lr"] is not None else 1e-8
        self.recab_coef = args["recab_coef"]
        self.args = args
        self.optim = None
        self.scheduler = None

        # State for the adaptive gradient reweighting (see _reweight_gradients).
        # Adam-style decay rates for the first/second gradient moments.
        self.beta1 = 0.9
        self.beta2 = 0.999
        # Global optimizer-step counter; used for bias correction and as the wandb "step".
        self.step = 0
        # Per-parameter state keyed by str(index in named_parameters()).
        self.grad_weight = {}
        self.old_grad = {}
        self.m = {}
        self.v = {}

        # Buffer sized by nb_classes holding 768-d features
        # (presumably the ViT embedding width -- TODO confirm).
        self.buffer = ProtoBuffer(
            max_size=args['nb_classes'],
            shape=(768,),
            device=self._device
        )

        total_params = sum(p.numel() for p in self._network.parameters())
        logging.info(f'{total_params:,} total parameters.')
        total_trainable_params = sum(p.numel() for p in self._network.fc.parameters() if p.requires_grad) + sum(p.numel() for p in self._network.prompt.parameters() if p.requires_grad)
        logging.info(f'{total_trainable_params:,} fc and prompt training parameters.')

    def _unwrapped_network(self):
        """Return the underlying network, stripping an ``nn.DataParallel`` wrapper if present."""
        if isinstance(self._network, nn.DataParallel):
            return self._network.module
        return self._network

    def after_task(self):
        """Mark every class trained so far as known."""
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager, **kwargs):
        """Build loaders for the next task and train on it.

        Args:
            data_manager: provides ``get_dataset`` and ``get_task_size``.
            n_tasks (kwarg, default 10): task count passed to ``OnlineSampler``
                (blurry mode only).
        """
        n_tasks = kwargs.get('n_tasks', 10)
        self._cur_task += 1

        if self.args['blurry']:
            # Blurry setting: the full label space is loaded and the
            # OnlineSampler (set to the current task below) picks the stream.
            self._total_classes = self.args['nb_classes']
            train_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="train", mode="train")
        else:
            self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
            logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
            train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="train")
        # The test split always covers every class up to _total_classes; built once.
        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
        self.test_dataset = test_dataset
        self.train_dataset = train_dataset

        if self.args['blurry']:
            self.train_dataset = IndexedDataset(self.train_dataset)
            # NOTE(review): positional args (10, 50, seed, False, 1) are passed
            # through unchanged; see OnlineSampler for their meaning.
            self.train_sampler = OnlineSampler(self.train_dataset, n_tasks, 10, 50, self.args['seed'], False, 1)
            self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.train_sampler, num_workers=num_workers, pin_memory=True)
            self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
            self.train_sampler.set_task(self._cur_task)
        else:
            self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
            self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=num_workers)

        self.data_manager = data_manager

        if len(self._multiple_gpus) > 1:
            print('Multiple GPUs')
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Move the model to the device, create optimizer/scheduler on task 0, then train."""
        self._network.to(self._device)

        if self._cur_task == 0:
            self.optim = self.get_optimizer()
            self.scheduler = self.get_scheduler()

        self._init_train(train_loader, test_loader)

    def data_weighting(self):
        """Set ``self.dw_k`` to a float32 vector of ones of length ``_total_classes + 1``."""
        self.dw_k = torch.ones(self._total_classes + 1, dtype=torch.float32, device=self._device)

    def get_optimizer(self):
        """Create the optimizer named by ``args['optimizer']`` over prompt + fc parameters.

        Raises:
            ValueError: if the optimizer name is not 'sgd', 'adam' or 'adamw'.
        """
        net = self._unwrapped_network()
        params = list(net.prompt.parameters()) + list(net.fc.parameters())
        name = self.args['optimizer']
        if name == 'sgd':
            return optim.SGD(params, momentum=0.9, lr=self.init_lr, weight_decay=self.weight_decay)
        if name == 'adam':
            return optim.Adam(params, lr=self.init_lr, weight_decay=self.weight_decay)
        if name == 'adamw':
            return optim.AdamW(params, lr=self.init_lr, weight_decay=self.weight_decay)
        raise ValueError(f"Unknown optimizer: {name!r}")

    def get_scheduler(self):
        """Create the LR scheduler named by ``args['scheduler']``.

        Returns ``None`` for 'constant'.

        Raises:
            ValueError: if the scheduler name is not 'cosine', 'steplr' or 'constant'.
        """
        name = self.args["scheduler"]
        if name == 'cosine':
            return CosineSchedule(self.optim, K=self.args["tuned_epoch"])
        if name == 'steplr':
            return optim.lr_scheduler.MultiStepLR(optimizer=self.optim, milestones=self.args["init_milestones"], gamma=self.args["init_lr_decay"])
        if name == 'constant':
            return None
        raise ValueError(f"Unknown scheduler: {name!r}")

    def random_lr_change(self):
        """Set every param group's LR to a value drawn uniformly from [1e-5, 1e-5 + 0.01)."""
        new_lr = torch.rand(1).item() * 0.01 + 0.00001
        for param_group in self.optim.param_groups:
            param_group['lr'] = new_lr

    def _recalibration_loss(self, net):
        """Cross-entropy of the fc head on features sampled from the buffer.

        The buffer is sampled on every call. A zero tensor is returned when the
        sample is empty or ``recab_coef <= 0``. Logits of classes that do not
        appear in the sampled labels are masked to ``-inf``.
        """
        mem_y, mem_f = self.buffer.random_retrieve_f(n_imgs=self.args['nb_classes'])
        if len(mem_y) == 0 or self.recab_coef <= 0:
            return torch.tensor(0.0).to(self._device)
        logits_ori = net.fc(mem_f)
        # Build the label set once instead of calling unique() per class.
        stored = set(mem_y.unique().long().tolist())
        logits_ori[:, [c for c in range(self.args['nb_classes']) if c not in stored]] = float('-inf')
        logits_ori = logits_ori[:, :self._total_classes]
        return F.cross_entropy(logits_ori, mem_y.long()).mean()

    def _reweight_gradients(self, net):
        """Multiply each parameter gradient in place by its adaptive weight and log stats.

        The first time a parameter has a gradient, its weight starts at 1 and the
        gradient is left unchanged. Afterwards, a bias-corrected Adam-style
        normalized gradient ``g = m_hat / (sqrt(v_hat) + 1e-8)`` is computed. The
        weight is updated as ``w = clamp(w + gamma * g * g_prev, 0, clamp)``, where
        ``g_prev`` is the value stored on the previous update. The raw gradient is
        then multiplied by ``w``.
        """
        fc_names = ("fc.weight", "fc.bias")
        for idx, (name, param) in enumerate(net.named_parameters()):
            grad = param.grad
            if grad is None:
                # Keep the previously stored gradient: storing None would make a
                # later update compute ``tensor * None`` and crash.
                continue
            key = str(idx)
            if key not in self.grad_weight:
                self.grad_weight[key] = 1.0
                self.m[key] = 0.0
                self.v[key] = 0.0
                # Copy, because param.grad may be zeroed/accumulated in place later.
                self.old_grad[key] = grad.detach().clone()
                continue

            self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * grad
            self.v[key] = self.beta2 * self.v[key] + (1 - self.beta2) * grad ** 2
            m_hat = self.m[key] / (1 - self.beta1 ** self.step)
            v_hat = self.v[key] / (1 - self.beta2 ** self.step)
            norm_grad = m_hat / (torch.sqrt(v_hat) + 1e-8)

            weight = self.grad_weight[key] + self.args['gamma'] * norm_grad * self.old_grad[key]
            self.grad_weight[key] = torch.clamp(weight, 0, self.args['clamp'])
            param.grad = self.grad_weight[key] * param.grad
            self.old_grad[key] = norm_grad

            if name in fc_names:
                # Norms over the first 100 entries of dim 0, in chunks of 10
                # (classes, assuming a standard Linear head).
                stats = {
                    f"grad_p_{name}_{lo}-{lo + 10}": param.grad[lo:lo + 10].norm().item()
                    for lo in range(0, 100, 10)
                }
            else:
                stats = {
                    f"grad_w_{name}": self.grad_weight[key].mean().item(),
                    f"grad_{name}": norm_grad.norm().item(),
                    f"grad_p_{name}": param.grad.norm().item(),
                }
            stats["step"] = self.step
            wandb.log(stats)

    def _init_train(self, train_loader, test_loader):
        """Train for ``args['tuned_epoch']`` epochs, evaluating every 5th epoch.

        In blurry mode the test loader is rebuilt after each epoch from the
        labels currently in the buffer (via ``OnlineTestSampler``). The process
        exits with status 1 if the loss becomes NaN.
        """
        net = self._unwrapped_network()
        prog_bar = tqdm(range(self.args['tuned_epoch']))
        info = ""  # stays empty (instead of unbound) when tuned_epoch == 0
        for epoch in prog_bar:
            self._network.train()

            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)

                logits, prompt_loss, _, _, _, f = self._network(inputs, train=True)
                logits = logits[:, :self._total_classes]
                # Restrict the batch softmax to classes present in this batch.
                present = set(targets.unique().tolist())
                logits[:, [c for c in range(self._total_classes) if c not in present]] = float('-inf')

                loss_recab = self._recalibration_loss(net)
                loss_supervised = F.cross_entropy(logits, targets.long()).mean() + self.recab_coef * loss_recab
                loss = loss_supervised + prompt_loss.sum()

                wandb.log({'lr': self.optim.param_groups[0]['lr'], 'step': self.step})

                self.optim.zero_grad()
                loss.backward()
                self._reweight_gradients(net)
                self.step += 1

                self.optim.step()
                loss_value = loss.item()
                losses += loss_value

                if math.isnan(loss_value):
                    print('loss is nan')
                    # Non-zero status so schedulers/sweeps see the run as failed.
                    sys.exit(1)

                wandb.log({
                    'loss': loss_value,
                    'loss_supervised': loss_supervised.item(),
                    'loss_recab': loss_recab.item(),
                    'step': self.step
                    })

                self.buffer.update(queries=None, keys=None, values=None, labels=targets.detach(), features=f.detach())

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            if self.scheduler:
                self.scheduler.step()

            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) if total else 0.0
            if self.args['blurry']:
                _, _, _, mem_y, _, _ = self.buffer.random_retrieve(n_imgs=self.args['nb_classes'])
                test_sampler = OnlineTestSampler(self.test_dataset, mem_y.unique())
                test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, sampler=test_sampler, num_workers=num_workers)
                self.test_loader = test_loader

            avg_loss = losses / max(len(train_loader), 1)
            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                self.args['tuned_epoch'],
                avg_loss,
                train_acc,
            )
            if (epoch + 1) % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info += ", Test_accy {:.2f}".format(test_acc)
            prog_bar.set_description(info)

        logging.info(info)

    def _eval_cnn(self, loader):
        """Return (top-``self.topk`` predicted labels [N, topk], true labels [N]) over ``loader``."""
        self._network.eval()
        y_pred, y_true = [], []
        for _, inputs, targets in loader:
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = self._network(inputs)[:, :self._total_classes]
            predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]  # [bs, topk]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_true)  # [N, topk]

    def _compute_accuracy(self, model, loader):
        """Top-1 accuracy (percent, 2 decimals) of ``model`` on ``loader``; 0.0 if empty."""
        model.eval()
        correct, total = 0, 0
        for _, inputs, targets in loader:
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = model(inputs)[:, :self._total_classes]
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)

        if total == 0:
            return 0.0
        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)


class _LRScheduler(object):
    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.
        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

class CosineSchedule(_LRScheduler):
    """Cosine-shaped LR decay over ``K`` epochs.

    For epoch ``e`` the rate is
    ``base_lr * cos(99 * pi * max(e, 1) / (200 * max(K - 1, 1)))``. It reaches
    ``base_lr * cos(99 * pi / 200)`` (about 0.0157 * base_lr) at epoch ``K - 1``.
    """

    def __init__(self, optimizer, K):
        # K must be set before the base constructor calls get_lr().
        self.K = K
        super().__init__(optimizer, -1)

    def cosine(self, base_lr):
        """Return the decayed LR for ``base_lr`` at ``self.last_epoch``."""
        epoch = max(self.last_epoch, 1)
        span = max(self.K - 1, 1)
        return base_lr * math.cos((99 * math.pi * epoch) / (200 * span))

    def get_lr(self):
        """Return one decayed LR per optimizer param group."""
        return list(map(self.cosine, self.base_lrs))