import vegas
import logging
import torch
import torch.nn as nn
import numpy as np
import math
import torch.utils.data as data
from math import exp as exp
from sklearn.metrics import roc_auc_score
from scipy import integrate
from CAT.model.abstract_model import AbstractModel
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
from sklearn.metrics import accuracy_score
from collections import namedtuple
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
# Named tuple with two fields, `log_prob` and `value`.
# NOTE(review): nothing in the visible part of this file references it — confirm it is still needed.
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

class IRT(nn.Module):
    """Item response network scoring (student, question) pairs.

    Every student has a ``theta`` vector, every question has an ``alpha``
    vector of the same width plus a scalar ``beta``.  The output is
    ``sigmoid(sum(alpha * theta) + beta)``.
    """

    def __init__(self, num_students, num_questions, num_dim=1):
        """Create the three embedding tables and initialise their weights.

        Args:
            num_students: number of rows of the ``theta`` table.
            num_questions: number of rows of the ``alpha`` and ``beta`` tables.
            num_dim: width of ``theta`` and ``alpha``; 1 gives plain IRT.
        """
        super().__init__()
        self.num_dim = num_dim
        self.num_students = num_students
        self.num_questions = num_questions
        self.theta = nn.Embedding(num_students, num_dim)
        self.alpha = nn.Embedding(num_questions, num_dim)
        self.beta = nn.Embedding(num_questions, 1)

        # Re-draw every weight tensor with Xavier-normal values, in
        # registration order (theta, alpha, beta).
        weights = [tensor for label, tensor in self.named_parameters() if 'weight' in label]
        for tensor in weights:
            nn.init.xavier_normal_(tensor)

    def forward(self, student_ids, question_ids):
        """Return the predicted probability for each (student, question) pair.

        Args:
            student_ids: integer tensor indexing the ``theta`` table.
            question_ids: integer tensor indexing ``alpha`` and ``beta``.
        Returns:
            Tensor of sigmoid outputs; the product ``alpha * theta`` is
            summed over dimension 1 with that dimension kept.
        """
        ability = self.theta(student_ids)
        slope = self.alpha(question_ids)
        offset = self.beta(question_ids)
        logit = torch.sum(slope * ability, dim=1, keepdim=True) + offset
        return torch.sigmoid(logit)

class IRTModel(AbstractModel):
    def __init__(self, **config):
        super().__init__()
        self.config = config
        self.model = None

    @property
    def name(self):
        return 'Item Response Theory'

    def init_model(self, data: Dataset):
        """Build a fresh ``IRT`` network sized for ``data``.

        Args:
            data: object exposing ``num_students`` and ``num_questions``.

        Side effects:
            Replaces ``self.model`` with a new one-dimensional ``IRT`` (the
            ``num_dim`` default is used) and records the question count in
            ``self.n_q``.
        """
        self.model = IRT(data.num_students, data.num_questions)
        self.n_q = data.num_questions


    def train(self, train_data: TrainDataset, log_step=1, tensorboard_dir=None):
        lr = self.config['learning_rate']
        batch_size = self.config['batch_size']
        epochs = self.config['num_epochs']
        device = self.config['device']
        self.model.to(device)
        logging.info('train on {}'.format(device))

        writer = SummaryWriter(tensorboard_dir) if tensorboard_dir else None

        train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        for ep in range(1, epochs + 1):
            loss = 0.0

            for cnt, (student_ids, question_ids, _, labels) in enumerate(train_loader):
                student_ids = student_ids.to(device)
                question_ids = question_ids.to(device)
                labels = labels.to(device).float()
                pred = self.model(student_ids, question_ids).view(-1)
                bz_loss = self._loss_function(pred, labels)
                optimizer.zero_grad()
                bz_loss.backward()
                optimizer.step()
                loss += bz_loss.data.float()

                if cnt % log_step == 0:
                    logging.info('Epoch [{}] Batch [{}]: loss={:.5f}'.format(ep, cnt, loss / cnt))
            if writer:
                writer.add_scalar('Loss/train_avg', loss / cnt, ep)


    def adaptest_save(self, path):
        """
        Save the model. Only save the parameters of questions(alpha, beta)
        """
        model_dict = self.model.state_dict()
        model_dict = {k:v for k,v in model_dict.items() if 'alpha' in k or 'beta' in k}
        torch.save(model_dict, path)

    def adaptest_load(self, path):
        """
        Reload the saved model
        """
        self.model.load_state_dict(torch.load(path), strict=False)
        self.model.to(self.config['device'])

    def adaptest_update(self, adaptest_data: AdapTestDataset, sid=None):
        """
        Update CDM with tested data
        """
        lr = self.config['learning_rate']
        batch_size = self.config['batch_size']
        epochs = self.config['num_epochs']
        device = self.config['device']
        optimizer = torch.optim.Adam(self.model.theta.parameters(), lr=lr)

        tested_dataset = adaptest_data.get_tested_dataset(last=False,ssid=sid)

        dataloader = torch.utils.data.DataLoader(tested_dataset, batch_size=batch_size, shuffle=True)
        for ep in range(1, epochs + 1):
            loss = 0.0
            log_steps = 100
            for cnt, (student_ids, question_ids, _, labels) in enumerate(dataloader):
                student_ids = student_ids.to(device)
                question_ids = question_ids.to(device)
                labels = labels.to(device).float()
                pred = self.model(student_ids, question_ids).view(-1)
                bz_loss = self._loss_function(pred, labels)
                optimizer.zero_grad()
                bz_loss.backward()
                optimizer.step()
                loss += bz_loss.data.float()
                # if cnt % log_steps == 0:
                    # print('Epoch [{}] Batch [{}]: loss={:.3f}'.format(ep, cnt, loss / cnt))
        return loss

    def bias_adjust(self, adaptest_data: AdapTestDataset, S_set):
        """ Apply bias correction to adjust for response errors
        Args:
            S_set : dict
        """
        pred_all = self.get_pred(adaptest_data)
        sum_bias = 0.0
        for sid in range(adaptest_data.num_students):
            q_idx = S_set[sid][-1]
            bias = self.get_bias(adaptest_data,pred_all,S_set[sid],q_idx,sid)
            self.model.theta.weight.data[sid] = torch.tensor(self.get_theta(sid).item() - bias,dtype=torch.float64)

    def one_student_update(self, adaptest_data: AdapTestDataset):
        lr = self.config['learning_rate']
        batch_size = self.config['batch_size']
        epochs = self.config['num_epochs']
        device = self.config['device']
        optimizer = torch.optim.Adam(self.model.theta.parameters(), lr=lr)

    def evaluate(self, adaptest_data: AdapTestDataset):
        """Score the model on every response stored in ``adaptest_data``.

        Args:
            adaptest_data: object exposing ``data`` (``data[sid][qid]`` is
                the label), ``concept_map`` (``concept_map[qid]`` is an
                iterable of concepts) and ``tested`` (``tested[sid]`` is an
                iterable of question ids).
        Returns:
            dict with
              'auc': ``roc_auc_score`` of the predictions,
              'cov': mean over students of (distinct concepts of the tested
                     questions) / (distinct concepts of all the student's
                     questions),
              'acc': ``accuracy_score`` after thresholding predictions at 0.5.
        Raises:
            ZeroDivisionError: if a student's questions map to no concept, or
                ``data`` holds no student.
        """
        # Named `responses` so the module-level `data` alias is not shadowed.
        responses = adaptest_data.data
        concept_map = adaptest_data.concept_map
        device = self.config['device']

        real = []
        pred = []
        # Remember the current mode so it is restored afterwards rather than
        # being forced to training mode.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for sid in responses:
                    question_ids = list(responses[sid].keys())
                    real += [responses[sid][qid] for qid in question_ids]
                    sid_tensor = torch.LongTensor([sid] * len(question_ids)).to(device)
                    qid_tensor = torch.LongTensor(question_ids).to(device)
                    pred += self.model(sid_tensor, qid_tensor).view(-1).tolist()
        finally:
            self.model.train(was_training)

        coverages = []
        for sid in responses:
            all_concepts = set()
            tested_concepts = set()
            for qid in responses[sid]:
                all_concepts.update(set(concept_map[qid]))
            for qid in adaptest_data.tested[sid]:
                tested_concepts.update(set(concept_map[qid]))
            coverages.append(len(tested_concepts) / len(all_concepts))
        cov = sum(coverages) / len(coverages)

        real = np.array(real)
        pred = np.array(pred)
        auc = roc_auc_score(real, pred)

        # Calculate accuracy
        threshold = 0.5  # You may adjust the threshold based on your use case
        binary_pred = (pred >= threshold).astype(int)
        acc = accuracy_score(real, binary_pred)

        return {
            'auc': auc,
            'cov': cov,
            'acc': acc
        }

    def get_pred(self, adaptest_data: AdapTestDataset):
        """
        Returns:
            predictions, dict[sid][qid]
        """
        data = adaptest_data.data
        concept_map = adaptest_data.concept_map
        device = self.config['device']

        pred_all = {}

        with torch.no_grad():
            self.model.eval()
            for sid in data:
                pred_all[sid] = {}
                student_ids = [sid] * len(data[sid])
                question_ids = list(data[sid].keys())
                student_ids = torch.LongTensor(student_ids).to(device)
                question_ids = torch.LongTensor(question_ids).to(device)
                output = self.model(student_ids, question_ids).view(-1).tolist()
                for i, qid in enumerate(list(data[sid].keys())):
                    pred_all[sid][qid] = output[i]
            self.model.train()

        return pred_all

    def _loss_function(self, pred, real):
        return -(real * torch.log(0.0001 + pred) + (1 - real) * torch.log(1.0001 - pred)).mean()
    
    def get_alpha(self, question_id):
        """ get alpha of one question
        Args:
            question_id: int, question id
        Returns:
            alpha of the given question, shape (num_dim, )
        """
        return self.model.alpha.weight.data.cpu().numpy()[question_id]
    
    def get_beta(self, question_id):
        """ get beta of one question
        Args:
            question_id: int, question id
        Returns:
            beta of the given question, shape (1, )
        """
        return self.model.beta.weight.data.cpu().numpy()[question_id]
    
    def get_theta(self, student_id):
        """ get theta of one student
        Args:
            student_id: int, student id
        Returns:
            theta of the given student, shape (num_dim, )
        """
        return self.model.theta.weight.data.cpu().numpy()[student_id]

    def get_bias(self, adaptest_data: AdapTestDataset, pred_all, S_set, q_idx, sid):
        """ get bias item
        Args:
            pred_all: dict, the questions you want to sample and their probability
            S_set : list
            q_idx : int, question id
            sid: int, student id
        Returns:
            bias item
        """
        pi_g = self.config['pi_g']
        pi_s = self.config['pi_s']

        S_set = list(S_set)
        S_set_q = S_set.copy()
        S_set_q.remove(q_idx)

        _,h_i_without_q = self.get_G_H(sid,pred_all,S_set_q)
        h_i_last_q = self.get_H_bias(pred_all,sid,q_idx)

        q_label = adaptest_data.get_score_by_sid_qid(sid,q_idx)
        q_pre = pred_all[sid][q_idx]
        t = h_i_without_q + h_i_last_q

        return (pi_g * (1 - q_label) + pi_s * q_label) * 1/t * (1 - 2 * q_label) * math.log(q_pre / (1 - q_pre))

    def get_H_bias(self, pred_all, sid, last_q_idx):
        """ get  H~ = (∇^2 ℓ_i(θ)~)
        Args:
            pred_all: dict, the questions you want to sample and their probability
            sid: int, student id
            last_q_idx: the last question id
        Returns:
            H~ = (∇^2 ℓ_i(θ)~)
        """
        return self.get_H(sid,last_q_idx,pred_all)

    def get_H(self, student_id, question_id, pred_all):
        """ get H
        Args:
            student_id: int, student id
            question_id: int, question id
            pred_all: dict, the questions you want to sample and their probability
        Returns:
            H: Hessian matrix
        """
        device = self.config['device']
        qid = torch.LongTensor([question_id]).to(device)
        alpha = self.model.alpha(qid).clone().detach().cpu()
        pred = pred_all[student_id][question_id]
        q = 1 - pred
        h = (q*pred*(alpha * alpha.T)).numpy()
        return h

    def get_G_H(self, sid, pred_all, S_set):
        """ get G=∑∇ℓ_i(θ) and H=∑∇^2 ℓ_i(θ)
        Args:
            sid: int, student id
            pred_all: dict, the questions you want to sample and their probability
            S_set : list
        Returns:
            G & H
        """
        if len(S_set)==0:
            return 0,0

        Pre_true={}
        Pre_false={}
        for qid, pred in pred_all[sid].items():
            Pre_true[qid] = pred
            Pre_false[qid] = 1 - pred

        l_i = 0.0
        h_i = 0.0

        for i,_ in pred_all[sid].items():
            if(i in S_set):
                # Calculate 2 possible gradient combinations
                gradients_theta1 = (Pre_true[i]-0.0) * self.get_alpha(i) #  Derivative of loss function L with respect to theta: (Pi-yi)*alpha
                gradients_theta2 = (Pre_true[i]-1.0) * self.get_alpha(i)

                h_gradients_theta1 = self.get_H(sid,i,pred_all)
                # Computing the expected gradient
                Expect = Pre_true[i] * gradients_theta1 + Pre_false[i] * gradients_theta2
                l_i = l_i + Expect
                h_i = h_i + h_gradients_theta1
        return l_i,h_i


    def E_f_S_t(self, sid, question_id, pred_all, S_set):
        """ get CFAT Strategy Questions delta
        Args:
            sid: int, student id
            question_id: int, question id
            pred_all: dict, the untest questions and their probability
            S_set: dict, chosen questions
        Returns:
            E_f_S_t: float, E_f_S_t of questions id
        """
        Sp_set = list(S_set)
        l_i, h_i = self.get_G_H(sid,pred_all,Sp_set)

        h_i = h_i + self.get_H(sid,question_id,pred_all)

        gradients_theta1 = (pred_all[sid][question_id] - 0.0) * self.get_alpha(question_id)
        gradients_theta2 = (pred_all[sid][question_id] - 1.0) * self.get_alpha(question_id)

        # Calculate the expected gradient
        Expect = pred_all[sid][question_id] * gradients_theta1 + (1 - pred_all[sid][question_id]) * gradients_theta2
        l_i = l_i + Expect

        F_sp = l_i * 1/h_i

        return abs(F_sp)




