from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, RandomSampler
from torch.optim import SGD
import gc
import copy
import math
import torch
from sklearn.metrics import accuracy_score
import numpy as np
from bert_model import BertForSequenceClassification
from transformers import BertConfig
from  itertools import cycle
class Learner(nn.Module):
    """
    Meta Learner
    """
    def __init__(self, args):
        """
        Build the classifier and the running state of the bilevel optimizer.

        :param args: configuration object. ``num_labels``, ``outer_update_lr``,
            ``inner_update_lr`` and ``hidden_size`` are read here; the object
            itself is kept on ``self.args`` for the other methods.
        """
        super().__init__()
        self.args = args
        self.num_labels = args.num_labels
        self.outer_update_lr = args.outer_update_lr
        self.inner_update_lr = args.inner_update_lr

        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        # Only the hidden width comes from args; depth, head count and
        # feed-forward width are fixed here.
        bert_config = BertConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=4,
            num_attention_heads=4,
            intermediate_size=2048,
        )
        self.model = BertForSequenceClassification(bert_config, self.num_labels)

        def zero_state():
            # One-element zero tensor living on the training device.
            return torch.zeros(1, device=self.device)

        # Upper-level variable: the weight that forward() puts on the sum of
        # squared parameter norms in the inner loss.
        self.reg = torch.tensor(1e-4, device=self.device, requires_grad=True)

        # Running statistics that forward() uses for its adaptive step sizes.
        self.hyper_momentum = zero_state()
        self.hyper_grad_last = zero_state()
        self.hyper_grad_diff_grad_norm = zero_state()
        # One accumulator of squared gradient norms per parameter tensor.
        self.grad_y_norm = [zero_state() for _ in self.model.parameters()]

        self.model.train()
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def forward(self, train_loader, val_loader):
        task_accs = []
        task_loss = []
        self.model.to(self.device)
        all_loss = []
        val_batch = cycle(val_loader)
        for idx, batch in enumerate(train_loader):
            tr_batch = {k: v.to(self.device) for k, v in batch.items()}
            outputs = self.model(input_ids=tr_batch['input_ids'], attention_mask=tr_batch['attention_mask'])
            inner_loss = self.criterion(outputs.view(-1, self.num_labels), tr_batch["labels"]) + self.reg * sum(
                [x.norm().pow(2) for x in self.model.parameters()])

            grad_y = torch.autograd.grad(inner_loss, self.model.parameters())
            for i, (param, g_param) in enumerate(zip(self.model.parameters(), grad_y)):
                self.grad_y_norm[i] +=  g_param.norm(2).pow(2)
                eta_y_t = self.args.inner_update_lr  / math.sqrt(self.grad_y_norm[i] + self.args.gamma**2)
                param.data = param.data - eta_y_t * grad_y[i]
            all_loss.append(inner_loss.item())

            vl_batch = {k: v.to(self.device) for k, v in val_batch.__next__().items()}
            val_outputs = self.model(input_ids=vl_batch['input_ids'], attention_mask=vl_batch['attention_mask'])
            outer_loss = self.criterion(val_outputs.view(-1, self.num_labels), vl_batch["labels"])
            hypergrad = self.hypergradient(outer_loss, vl_batch, tr_batch)[0]

            # update the upper-level variables
            self.hyper_grad_diff_grad_norm += (hypergrad - self.hyper_grad_last).norm(2).pow(2)
            self.hyper_grad_last = copy.deepcopy(hypergrad)
            alpha_t = self.args.alpha / torch.sqrt(self.args.alpha**2 + self.hyper_grad_diff_grad_norm)
            alpha_t_prime = self.args.alpha / torch.sqrt(self.args.alpha**2 + self.hyper_grad_diff_grad_norm + sum([g_y for g_y in self.grad_y_norm]))
            eta_x_t = self.args.outer_update_lr * math.sqrt(alpha_t_prime)
            self.hyper_momentum = (1 - alpha_t) * self.hyper_momentum + alpha_t * hypergrad
            self.reg.data = self.reg.data - eta_x_t * self.hyper_momentum/self.hyper_momentum.norm(2)

            val_logits = F.softmax(val_outputs, dim=1)
            pre_label_id = torch.argmax(val_logits, dim=1)
            pre_label_id = pre_label_id.detach().cpu().numpy().tolist()
            val_label_id = vl_batch["labels"].detach().cpu().numpy().tolist()
            acc = accuracy_score(pre_label_id, val_label_id)
            task_accs.append(acc)
            task_loss.append(outer_loss.detach().cpu())
            torch.cuda.empty_cache()
            print(f'{self.args.methods} Trian loss: {inner_loss.item():.4f}, Val loss: {outer_loss.item():.4f}')

        return np.mean(task_accs),  np.mean(task_loss)

    def hypergradient(self, out_loss, val_batch, tr_batch):
        Fy_gradient = torch.autograd.grad(out_loss, self.model.parameters(), retain_graph=True)
        F_gradient = [g_param.view(-1) for g_param in Fy_gradient]
        v_0 = torch.unsqueeze(torch.reshape(torch.hstack(F_gradient), [-1]), 1).detach()
        # Fx_gradient = torch.autograd.grad(out_loss, self.reg)

        # calculate the neumann series
        z_list = []
        outputs = self.model(input_ids=tr_batch["input_ids"], attention_mask=tr_batch["attention_mask"])
        inner_loss = F.cross_entropy(outputs.view(-1, self.args.num_labels), tr_batch["labels"]) + self.reg * sum(
            [x.norm().pow(2) for x in self.model.parameters()])
        G_gradient = []
        Gy_gradient = torch.autograd.grad(inner_loss, self.model.parameters(), create_graph=True)
        for g_grad, param in zip(Gy_gradient, self.model.parameters()):
            G_gradient.append((param - self.args.neumann_lr * g_grad).view(-1))
        G_gradient = torch.reshape(torch.hstack(G_gradient), [-1])

        for _ in range(self.args.hessian_q):
            Jacobian = torch.matmul(G_gradient, v_0)
            v_new = torch.autograd.grad(Jacobian, self.model.parameters(), retain_graph=True)
            v_params = [v_param.view(-1) for v_param in v_new]
            v_0 = torch.unsqueeze(torch.reshape(torch.hstack(v_params), [-1]), 1).detach()
            z_list.append(v_0)
        v_Q = self.args.neumann_lr * (v_0 + torch.sum(torch.stack(z_list), dim=0))

        # Gyx_gradient
        outputs = self.model(input_ids=val_batch['input_ids'], attention_mask=val_batch['attention_mask'])
        loss = F.cross_entropy(outputs.view(-1, self.args.num_labels), val_batch["labels"]) + self.reg * sum(
            [x.norm().pow(2) for x in self.model.parameters()])
        Gy_gradient = torch.autograd.grad(loss, self.model.parameters(), retain_graph=True, create_graph=True)
        Gy_params = [Gy_param.view(-1) for Gy_param in Gy_gradient]
        Gy_gradient_flat = torch.reshape(torch.hstack(Gy_params), [-1])
        Gyxv_gradient = torch.autograd.grad(-torch.matmul(Gy_gradient_flat, v_Q.detach()), self.reg)
        return Gyxv_gradient


    def evaluate(self, dataloader):
        self.model.eval()
        total_loss = 0
        total_correct = 0
        total_samples = 0

        criterion = torch.nn.CrossEntropyLoss()

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                logits = self.model(input_ids=batch["input_ids"],
                                attention_mask=batch["attention_mask"])
                labels = batch["labels"]
                loss = criterion(logits, labels)

                preds = logits.argmax(dim=-1)
                total_correct += (preds == labels).sum().item()
                total_loss += loss.item()
                total_samples += labels.size(0)

        avg_loss = total_loss / len(dataloader)
        accuracy = total_correct / total_samples
        return avg_loss, accuracy