from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, RandomSampler
from torch.optim import SGD
import gc
import copy
import math
import torch
from sklearn.metrics import accuracy_score
import numpy as np
from bert_model import BertForSequenceClassification
from transformers import BertConfig
from  itertools import cycle
class Learner(nn.Module):
    """
    Meta Learner.

    Wraps a small BERT sequence classifier and trains it with a bilevel
    scheme: the model weights are the lower-level variables and the scalar
    weight-decay coefficient ``self.reg`` is the upper-level variable.  Both
    levels use per-step adaptive step sizes that are accumulated from
    gradient norms (``eta_y_t``, ``eta_v_t``, ``eta_x_t``).
    """

    def __init__(self, args):
        """
        :param args: namespace read for ``num_labels``, ``outer_update_lr``,
            ``inner_update_lr`` and ``hidden_size`` here, and for ``methods``
            when ``forward`` logs its progress line.
        """
        super().__init__()
        self.args = args
        self.num_labels = args.num_labels
        self.outer_update_lr = args.outer_update_lr
        self.inner_update_lr = args.inner_update_lr
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        config = BertConfig(hidden_size=args.hidden_size, num_hidden_layers=4, num_attention_heads=4,
                            intermediate_size=2048)
        self.model = BertForSequenceClassification(config, self.num_labels)
        # Auxiliary vector, one tensor per model parameter; it is the
        # ``grad_outputs`` of the Hessian-vector products in ``hypergradient``.
        self.v = [torch.ones_like(param).to(self.device) for param in self.model.parameters()]
        # Upper-level variable: coefficient of the squared-L2 penalty on the weights.
        self.reg = torch.tensor(1e-4).to(self.device)
        self.reg.requires_grad = True
        # The next three tensors are initialised here but not referenced by
        # any other method of this class.
        self.hyper_momentum = torch.zeros(1).to(self.device)
        self.hyper_grad_last = torch.zeros(1).to(self.device)
        self.hyper_grad_diff_grad_norm = torch.zeros(1).to(self.device)
        # Running per-parameter sum of lower-level gradient norms.
        self.grad_y_norm = [torch.zeros(1).to(self.device) for _ in self.model.parameters()]
        # Step-size accumulators; updates divide by these values.
        self.eta_y_t = [self.inner_update_lr for _ in self.model.parameters()]
        self.eta_v_t = [self.inner_update_lr for _ in self.model.parameters()]
        self.eta_x_t = self.args.outer_update_lr
        self.model.train()
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def forward(self, train_loader, val_loader):
        """
        Run one pass over ``train_loader``, updating the model weights, the
        auxiliary vector ``self.v`` and ``self.reg`` once per training batch.

        :param train_loader: iterable of dict batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        :param val_loader: iterable of batches with the same keys; it is
            cycled so that every training batch is paired with one
            validation batch.
        :return: ``(mean validation accuracy, mean validation loss)`` over
            the processed steps (``nan`` for both if ``train_loader`` is empty).
        """
        task_accs = []
        task_loss = []
        self.model.to(self.device)
        # NOTE: itertools.cycle stores every batch it has yielded, so after the
        # first pass the validation batches are replayed from memory in the
        # same order.
        val_batches = cycle(val_loader)
        for batch in train_loader:
            params = list(self.model.parameters())
            tr_batch = {k: v.to(self.device) for k, v in batch.items()}
            outputs = self.model(input_ids=tr_batch['input_ids'], attention_mask=tr_batch['attention_mask'])
            inner_loss = self.criterion(outputs.view(-1, self.num_labels), tr_batch["labels"]) + self.reg * sum(
                [x.norm().pow(2) for x in params])

            # Lower-level step: one gradient update per parameter tensor with
            # its own accumulated step size.
            grad_y = torch.autograd.grad(inner_loss, params)
            for i, (param, g_param) in enumerate(zip(params, grad_y)):
                # NOTE(review): grad_y_norm is a running *sum* of norms and its
                # square is added to eta_y_t**2 at every step, unlike eta_v_t
                # which adds only the current squared norm - confirm intended.
                self.grad_y_norm[i] += g_param.norm(2)
                self.eta_y_t[i] = math.sqrt(self.eta_y_t[i]**2 + self.grad_y_norm[i].pow(2))
                param.data = param.data - (1.0/self.eta_y_t[i]) * g_param

            vl_batch = {k: v.to(self.device) for k, v in next(val_batches).items()}
            val_outputs = self.model(input_ids=vl_batch['input_ids'], attention_mask=vl_batch['attention_mask'])
            outer_loss = self.criterion(val_outputs.view(-1, self.num_labels), vl_batch["labels"])
            hypergrad, phi = self.hypergradient(outer_loss, vl_batch, tr_batch)

            # update the upper-level variables
            eta_x_t = phi * self.eta_x_t
            self.reg.data = self.reg.data - (1.0/eta_x_t) * hypergrad

            # argmax over the raw logits; softmax would not change the winner.
            pre_label_id = torch.argmax(val_outputs, dim=1).detach().cpu().numpy().tolist()
            val_label_id = vl_batch["labels"].detach().cpu().numpy().tolist()
            task_accs.append(accuracy_score(val_label_id, pre_label_id))
            # Store plain floats so np.mean below never has to convert tensors.
            task_loss.append(outer_loss.item())
            torch.cuda.empty_cache()
            print(f'{self.args.methods} Trian loss: {inner_loss.item():.4f}, Val loss: {outer_loss.item():.4f}')

        return np.mean(task_accs),  np.mean(task_loss)

    def hypergradient(self, out_loss, val_batch, tr_batch):
        """
        Update the auxiliary vector ``self.v`` and return the gradient used
        to update ``self.reg``.

        :param out_loss: validation loss whose graph is still alive; it is
            differentiated w.r.t. the model parameters.
        :param val_batch: validation batch already on ``self.device``.
        :param tr_batch: training batch already on ``self.device``.
        :return: ``(hypergrad, phi)`` where ``hypergrad`` is the gradient
            w.r.t. ``self.reg`` and ``phi`` is the step-size value computed
            for the last parameter tensor in the loop below.
        """
        params = list(self.model.parameters())
        Fy_gradient = torch.autograd.grad(out_loss, params)

        # Regularised training loss; create_graph=True keeps its gradient
        # differentiable so a Hessian-vector product can be taken.
        outputs = self.model(input_ids=tr_batch["input_ids"], attention_mask=tr_batch["attention_mask"])
        inner_loss = F.cross_entropy(outputs.view(-1, self.args.num_labels), tr_batch["labels"]) + self.reg * sum(
            [x.norm().pow(2) for x in params])
        Gy_gradient = torch.autograd.grad(inner_loss, params, create_graph=True)
        Gyyv_gradient = torch.autograd.grad(Gy_gradient, params, grad_outputs=self.v)
        for i in range(len(self.v)):
            # Residual (H v - grad_y F) drives both the step-size accumulator
            # and the update of v.
            residual = Gyyv_gradient[i].detach() - Fy_gradient[i].detach()
            self.eta_v_t[i] = math.sqrt(self.eta_v_t[i]**2 + residual.norm(2).pow(2))
            # NOTE(review): phi is overwritten on every iteration, so the value
            # returned to the caller belongs to the last parameter tensor only.
            phi = max(self.eta_v_t[i], self.eta_y_t[i])
            self.v[i] -= (1.0/phi) * residual
        torch.cuda.empty_cache()
        # Gyx_gradient
        # NOTE(review): this uses val_batch rather than tr_batch.  Only the
        # ``self.reg * ||w||^2`` term of the loss depends on self.reg, so the
        # mixed derivative should not depend on the batch - confirm.
        outputs = self.model(input_ids=val_batch['input_ids'], attention_mask=val_batch['attention_mask'])
        loss = F.cross_entropy(outputs.view(-1, self.args.num_labels), val_batch["labels"]) + self.reg * sum(
            [x.norm().pow(2) for x in params])
        Gy_gradient = torch.autograd.grad(-loss, params, retain_graph=True, create_graph=True)
        Gyxv_gradient = torch.autograd.grad(Gy_gradient, self.reg, grad_outputs=self.v)
        self.eta_x_t = math.sqrt(self.eta_x_t**2 + (Gyxv_gradient[0].detach()).norm(2).pow(2))
        return Gyxv_gradient[0], phi

    def evaluate(self, dataloader):
        """
        Compute the mean batch loss and the accuracy over ``dataloader``
        without touching any trainable state.

        The model is switched to eval mode for the duration of the call and
        put back into whatever mode it was in before, so training that
        continues afterwards keeps dropout enabled.

        :param dataloader: iterable of dict batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        :return: ``(avg_loss, accuracy)``; both are ``nan`` when the loader
            yields no samples.
        """
        was_training = self.model.training
        # Batches are moved to self.device below, so the model must be there too.
        self.model.to(self.device)
        self.model.eval()

        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        num_batches = 0

        try:
            with torch.no_grad():
                for batch in dataloader:
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    logits = self.model(input_ids=batch["input_ids"],
                                        attention_mask=batch["attention_mask"])
                    labels = batch["labels"]

                    total_loss += F.cross_entropy(logits, labels).item()
                    total_correct += (logits.argmax(dim=-1) == labels).sum().item()
                    total_samples += labels.size(0)
                    num_batches += 1
        finally:
            # Restore the caller's mode even if a batch raised.
            self.model.train(was_training)

        if num_batches == 0 or total_samples == 0:
            return float('nan'), float('nan')

        avg_loss = total_loss / num_batches
        accuracy = total_correct / total_samples
        return avg_loss, accuracy