import logging
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils_online.inc_net_online import PromptVitNetOnline
from models_online.base import BaseLearner
from utils.toolkit import tensor2numpy
from buffer.buffer import ProtoBuffer, Reservoir, QueryReservoir
from utils_online.si_blurry import IndexedDataset, OnlineSampler, OnlineTestSampler
import wandb
import sys

# Tune the model at the first session with VPT, and then conduct simple shot.
# NOTE(review): this header may be stale -- in this file the optimizer is created
# only for the first task, but `_init_train` runs for every task. Confirm intent.
#
# Value passed as `num_workers` to every DataLoader constructed in this file.
num_workers = 8

class Learner(BaseLearner):
    def __init__(self, args):
        """Build the network, cache hyper-parameters and prepare training state.

        Args:
            args: Configuration mapping. Keys read here: ``batch_size``,
                ``init_lr``, ``weight_decay``, ``min_lr``, ``recab_coef``,
                ``nb_classes`` and ``freeze``. ``weight_decay`` and ``min_lr``
                may be ``None``, in which case 0.0005 and 1e-8 are used.
        """
        super().__init__(args)

        self._network = PromptVitNetOnline(args, True)

        # Plain hyper-parameters; a ``None`` entry selects the built-in default.
        self.batch_size = args["batch_size"]
        self.init_lr = args["init_lr"]
        configured_decay = args["weight_decay"]
        self.weight_decay = 0.0005 if configured_decay is None else configured_decay
        configured_min_lr = args["min_lr"]
        self.min_lr = 1e-8 if configured_min_lr is None else configured_min_lr
        self.recab_coef = args["recab_coef"]
        self.args = args

        # Created lazily by `_train` on the first task.
        self.optim = None
        self.scheduler = None

        # FGH state: decay rates, global step counter and the per-parameter
        # dictionaries consumed by the gradient reweighting in `_init_train`.
        self.beta1, self.beta2 = 0.9, 0.999
        self.step = 0
        self.grad_weight = {}
        self.old_grad = {}
        self.m = {}
        self.v = {}

        # One slot per class, each holding a 768-dimensional vector.
        self.buffer = ProtoBuffer(
            max_size=args['nb_classes'],
            shape=(768,),
            device=self._device,
        )

        # Optional freezing: the original backbone is frozen entirely, while the
        # tuned backbone freezes only parameters whose name starts with one of
        # the configured prefixes.
        if self.args["freeze"]:
            for frozen in self._network.original_backbone.parameters():
                frozen.requires_grad = False

            for param_name, candidate in self._network.backbone.named_parameters():
                if param_name.startswith(tuple(self.args["freeze"])):
                    candidate.requires_grad = False

        backbone = self._network.backbone
        total_params = sum(w.numel() for w in backbone.parameters())
        logging.info(f'{total_params:,} model total parameters.')
        total_trainable_params = sum(
            w.numel() for w in backbone.parameters() if w.requires_grad
        )
        logging.info(f'{total_trainable_params:,} model training parameters.')

        # When only a subset is trainable, list each trainable tensor and its size.
        if total_params == total_trainable_params:
            return
        for param_name, candidate in backbone.named_parameters():
            if not candidate.requires_grad:
                continue
            logging.info("{}: {}".format(param_name, candidate.numel()))

    def after_task(self):
        """Record the current total class count as the number of known classes.

        Presumably invoked by the training driver once a task has finished --
        confirm against the caller.
        """
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager, **kwargs):
        """Advance to the next task, build its data loaders and train on it.

        In blurry mode (``args['blurry']`` truthy) the train set spans all
        ``args['nb_classes']`` classes and is served through an ``OnlineSampler``
        positioned on the current task. Otherwise the train set holds only the
        classes introduced by the current task. In both modes the test set
        covers classes ``0 .. self._total_classes - 1``.

        Side effects: increments ``self._cur_task`` and sets
        ``self._total_classes``, ``self.train_dataset``, ``self.test_dataset``,
        ``self.train_loader``, ``self.test_loader`` and ``self.data_manager``
        (plus ``self.train_sampler`` in blurry mode).

        Args:
            data_manager: Object providing ``get_dataset`` and ``get_task_size``.
            **kwargs: ``n_tasks`` (default 10) is forwarded to ``OnlineSampler``
                in blurry mode; other keys are ignored.
        """
        n_tasks = kwargs.get('n_tasks', 10)
        self._cur_task += 1
        blurry = self.args['blurry']

        if blurry:
            self._total_classes = self.args['nb_classes']
            train_classes = np.arange(0, self._total_classes)
        else:
            self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
            # self._network.update_fc(self._total_classes)
            logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
            train_classes = np.arange(self._known_classes, self._total_classes)

        # Each dataset is built exactly once; the test loader below reuses the
        # stored test dataset instead of asking the data manager for it again.
        train_dataset = data_manager.get_dataset(train_classes, source="train", mode="train")
        self.test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )

        if blurry:
            self.train_dataset = IndexedDataset(train_dataset)
            # The positional values 10, 50, False and 1 are OnlineSampler
            # settings; see its definition for their meaning.
            self.train_sampler = OnlineSampler(
                self.train_dataset, n_tasks, 10, 50, self.args['seed'], False, 1
            )
            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                sampler=self.train_sampler,
                num_workers=num_workers,
                pin_memory=True,
            )
            self.test_loader = DataLoader(
                self.test_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True,
            )
            self.train_sampler.set_task(self._cur_task)
        else:
            self.train_dataset = train_dataset
            self.train_loader = DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True,
                num_workers=num_workers,
            )
            self.test_loader = DataLoader(
                self.test_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=num_workers,
            )

        self.data_manager = data_manager

        multi_gpu = len(self._multiple_gpus) > 1
        if multi_gpu:
            print('Multiple GPUs')
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        try:
            self._train(self.train_loader, self.test_loader)
        finally:
            # Always restore the bare module, even when training raises, so the
            # learner is never left holding the DataParallel wrapper.
            if multi_gpu:
                self._network = self._network.module
    def _train(self, train_loader, test_loader):
        """Move the network to the target device and train on the given loaders.

        The optimizer and scheduler are built only when the current task index
        is 0 and are reused unchanged for later tasks. Prompt re-initialisation
        (`_init_prompt`) and per-task optimizer re-creation are currently
        disabled.

        Args:
            train_loader: Loader forwarded to `_init_train` for optimisation.
            test_loader: Loader forwarded to `_init_train` for evaluation.
        """
        self._network.to(self._device)

        is_first_task = self._cur_task == 0
        if is_first_task:
            self.optim = self.get_optimizer()
            self.scheduler = self.get_scheduler()

        self._init_train(train_loader, test_loader)

    def get_optimizer(self):
        if self.args['optimizer'] == 'sgd':
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()), 
                momentum=0.9, 
                lr=self.init_lr,
                weight_decay=self.weight_decay
            )
        elif self.args['optimizer'] == 'adam':
            optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=self.init_lr, 
                weight_decay=self.weight_decay
            )
        elif self.args['optimizer'] == 'adamw':
            optimizer = optim.AdamW(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=self.init_lr, 
                weight_decay=self.weight_decay
            )

        return optimizer
    
    def get_scheduler(self):
        if self.args["scheduler"] == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optim, T_max=self.args['tuned_epoch'], eta_min=self.min_lr)
        elif self.args["scheduler"] == 'steplr':
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer=self.optim, milestones=self.args["init_milestones"], gamma=self.args["init_lr_decay"])
        elif self.args["scheduler"] == 'constant':
            scheduler = None

        return scheduler

    def _init_prompt(self, optimizer):
        args = self.args
        model = self._network.backbone
        task_id = self._cur_task

        # Transfer previous learned prompt params to the new prompt
        if args["prompt_pool"] and args["shared_prompt_pool"]:
            prev_start = (task_id - 1) * args["top_k"]
            prev_end = task_id * args["top_k"]

            cur_start = prev_end
            cur_end = (task_id + 1) * args["top_k"]

            if (prev_end > args["size"]) or (cur_end > args["size"]):
                pass
            else:
                cur_idx = (slice(None), slice(None), slice(cur_start, cur_end)) if args["use_prefix_tune_for_e_prompt"] else (slice(None), slice(cur_start, cur_end))
                prev_idx = (slice(None), slice(None), slice(prev_start, prev_end)) if args["use_prefix_tune_for_e_prompt"] else (slice(None), slice(prev_start, prev_end))

                with torch.no_grad():
                    model.e_prompt.prompt.grad.zero_()
                    model.e_prompt.prompt[cur_idx] = model.e_prompt.prompt[prev_idx]
                    optimizer.param_groups[0]['params'] = model.parameters()
                
        # Transfer previous learned prompt param keys to the new prompt
        if args["prompt_pool"] and args["shared_prompt_key"]:
            prev_start = (task_id - 1) * args["top_k"]
            prev_end = task_id * args["top_k"]

            cur_start = prev_end
            cur_end = (task_id + 1) * args["top_k"]

            if (prev_end > args["size"]) or (cur_end > args["size"]):
                pass
            else:
                cur_idx = (slice(cur_start, cur_end))
                prev_idx = (slice(prev_start, prev_end))

            with torch.no_grad():
                model.e_prompt.prompt_key.grad.zero_()
                model.e_prompt.prompt_key[cur_idx] = model.e_prompt.prompt_key[prev_idx]
                optimizer.param_groups[0]['params'] = model.parameters()

    def _init_train(self, train_loader, test_loader):
        """Run ``args['tuned_epoch']`` training epochs over ``train_loader``.

        For every batch:
            1. Forward the inputs (``task_id=-1, train=True``) and keep the first
               ``self._total_classes`` logits.
            2. Set the logits of classes absent from the batch to ``-inf``.
            3. If the buffer returns samples and ``self.recab_coef > 0``, add a
               cross-entropy "recalibration" loss computed by the head on the
               buffered features.
            4. Optionally subtract the pull-constraint term.
            5. Back-propagate, reweight the gradients, step the optimizer.
            6. Log to wandb and push the batch features into the buffer.

        After each epoch the scheduler (if any) is stepped. In blurry mode the
        test loader is rebuilt from the labels currently held by the buffer and
        stored on ``self.test_loader``. Test accuracy is computed every fifth
        epoch.

        The process exits via ``sys.exit(0)`` as soon as the loss becomes NaN.

        Args:
            train_loader: Iterable of ``(_, inputs, targets)`` batches; the first
                element of each batch is ignored.
            test_loader: Loader used for the periodic evaluation (replaced each
                epoch in blurry mode).
        """
        # nn.DataParallel exposes the wrapped network only through ``.module``;
        # attribute access must go through the bare module, while the forward
        # pass below still uses ``self._network`` so data parallelism is kept.
        net = self._network.module if isinstance(self._network, nn.DataParallel) else self._network

        info = None  # stays None when zero epochs are configured
        prog_bar = tqdm(range(self.args['tuned_epoch']))
        for epoch in prog_bar:
            net.backbone.train()
            net.original_backbone.eval()

            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)

                output = self._network(inputs, task_id=-1, train=True)
                logits = output["logits"][:, :self._total_classes]
                f = output['pre_logits']

                # Only classes present in this batch may compete in the softmax.
                absent = self._absent_class_indices(targets, self._total_classes, logits.device)
                logits[:, absent] = float('-inf')

                mem_y, mem_f = self.buffer.random_retrieve_f(n_imgs=self.args['nb_classes'])

                loss_recab = torch.tensor([0]).to(self._device)

                # fc recalibration
                if len(mem_y) > 0 and self.recab_coef > 0:
                    logits_ori = net.forward_head(mem_f)
                    absent_mem = self._absent_class_indices(
                        mem_y, self.args['nb_classes'], logits_ori.device
                    )
                    logits_ori[:, absent_mem] = float('-inf')
                    logits_ori = logits_ori[:, :self._total_classes]
                    loss_recab = F.cross_entropy(logits_ori, mem_y.long()).mean()

                loss = F.cross_entropy(logits, targets.long()).mean() + self.recab_coef * loss_recab

                if self.args["pull_constraint"] and 'reduce_sim' in output:
                    loss = loss - self.args["pull_constraint_coeff"] * output['reduce_sim']

                self.optim.zero_grad()
                loss.backward()

                self._reweight_gradients(net)
                self.step += 1

                self.optim.step()
                loss_value = loss.item()
                losses += loss_value

                if np.isnan(loss_value):
                    # NOTE: exit status 0 is kept from the original behaviour.
                    logging.error("Loss became NaN at step %s; aborting.", self.step)
                    sys.exit(0)

                wandb.log({
                    'loss': loss_value,
                    'loss_recab': loss_recab.item(),
                    'step': self.step
                    })

                # buffer update
                self.buffer.update(queries=None, keys=None, values=None, labels=targets.detach(), features=f.detach())

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            if self.scheduler:
                self.scheduler.step()

            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if self.args['blurry']:
                # Evaluate only on the labels currently returned by the buffer.
                _, _, _, mem_y, _, _ = self.buffer.random_retrieve(n_imgs=self.args['nb_classes'])
                test_sampler = OnlineTestSampler(self.test_dataset, mem_y.unique())
                test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, sampler=test_sampler, num_workers=num_workers)
                self.test_loader = test_loader

            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                self.args['tuned_epoch'],
                losses / len(train_loader),
                train_acc,
            )
            if (epoch + 1) % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info += ", Test_accy {:.2f}".format(test_acc)
            prog_bar.set_description(info)

        if info is not None:
            logging.info(info)

    @staticmethod
    def _absent_class_indices(labels, num_classes, device):
        """Return the ids in ``range(num_classes)`` that do not occur in ``labels``."""
        present = labels.unique().long().to(device)
        class_ids = torch.arange(num_classes, device=device)
        is_present = (class_ids.unsqueeze(1) == present.unsqueeze(0)).any(dim=1)
        return torch.nonzero(~is_present).squeeze(1)

    def _reweight_gradients(self, net):
        """Scale every parameter gradient of ``net`` by its learned weight.

        State is keyed by the parameter's position in ``net.named_parameters()``.
        The first time a parameter has a gradient its weight is set to 1 and its
        moment estimates to 0. On later visits the first/second moment estimates
        (decay ``beta1``/``beta2``, bias-corrected with ``self.step``) yield a
        normalised gradient; the weight is increased by
        ``args['gamma'] * normalised * previous`` and clamped to
        ``[0, args['clamp']]`` before multiplying ``param.grad``.
        """
        for idx, (name, param) in enumerate(net.named_parameters()):
            curr_grad = param.grad
            if curr_grad is None:
                # Keep the last stored gradient so that a later update never
                # multiplies by None.
                continue

            key = str(idx)
            if key in self.grad_weight:
                self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * curr_grad
                self.v[key] = self.beta2 * self.v[key] + (1 - self.beta2) * curr_grad ** 2
                m_hat = self.m[key] / (1 - self.beta1 ** self.step)
                v_hat = self.v[key] / (1 - self.beta2 ** self.step)
                curr_grad = m_hat / (torch.sqrt(v_hat) + 1e-8)

                self.grad_weight[key] = self.grad_weight[key] + self.args['gamma'] * curr_grad * self.old_grad[key]
                self.grad_weight[key] = torch.clamp(self.grad_weight[key], 0, self.args['clamp'])

                param.grad = self.grad_weight[key] * param.grad

                self._log_gradient_stats(name, key, curr_grad, param)
            else:
                self.grad_weight[key] = 1.0
                self.m[key] = 0.0
                self.v[key] = 0.0
                # Clone: ``param.grad`` itself may be zeroed in place by a later
                # ``zero_grad`` call, which would wipe the stored copy.
                curr_grad = curr_grad.detach().clone()
            self.old_grad[key] = curr_grad

    def _log_gradient_stats(self, name, key, curr_grad, param):
        """Send gradient statistics for one parameter to wandb."""
        if name in ("backbone.head.weight", "backbone.head.bias"):
            # The head is logged per block of ten rows, covering rows 0-99.
            head = "fc.weight" if "weight" in name else "fc.bias"
            record = {
                f"grad_p_{head}_{lo}-{lo + 10}": param.grad[lo:lo + 10].norm().item()
                for lo in range(0, 100, 10)
            }
            record["step"] = self.step
        else:
            record = {
                f"grad_w_{name}": self.grad_weight[key].mean().item(),
                f"grad_{name}": curr_grad.norm().item(),
                f"grad_p_{name}": param.grad.norm().item(),
                "step": self.step,
            }
        wandb.log(record)

    def _eval_cnn(self, loader):
        self._network.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = self._network(inputs, task_id=self._cur_task)["logits"][:, :self._total_classes]
            predicts = torch.topk(
                outputs, k=self.topk, dim=1, largest=True, sorted=True
            )[
                1
            ]  # [bs, topk]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_true)  # [N, topk]

    def _compute_accuracy(self, model, loader):
        """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

        The model is switched to eval mode and run with
        ``task_id=self._cur_task``; only the first ``self._total_classes``
        logits are considered.

        Args:
            model: Callable network returning a mapping with a ``"logits"`` entry.
            loader: Iterable of ``(_, inputs, targets)`` batches; the first
                element of each batch is ignored.

        Returns:
            Accuracy in percent, rounded to two decimals.
        """
        model.eval()
        n_correct = 0
        n_seen = 0
        for _, inputs, targets in loader:
            inputs = inputs.to(self._device)
            with torch.no_grad():
                result = model(inputs, task_id=self._cur_task)
                scores = result["logits"][:, :self._total_classes]
            best = torch.max(scores, dim=1)[1]
            # Targets are compared on the CPU, where the loader left them.
            hits = best.cpu() == targets
            n_correct += hits.sum()
            n_seen += len(targets)

        percentage = tensor2numpy(n_correct) * 100 / n_seen
        return np.around(percentage, decimals=2)