import collections
import itertools
import numpy as np
from scipy.stats import pearsonr
import torch 
from tqdm import tqdm

def compute_f1(prediction, ground_truth):
    """Return the bag-of-tokens F1 score between two strings.

    Both strings are lower-cased and split on whitespace. No punctuation or
    article normalisation is applied. Tokens are matched with multiplicity
    (a Counter intersection).

    If either side has no tokens, the score is 1.0 when both are empty and
    0.0 otherwise.

    Returns:
        float in [0, 1]. Every path returns a float, including the empty cases.
    """
    # str.split() with no argument already discards leading/trailing whitespace.
    pred_tokens = prediction.lower().split()
    gt_tokens = ground_truth.lower().split()

    if not pred_tokens or not gt_tokens:
        return float(pred_tokens == gt_tokens)

    common = collections.Counter(pred_tokens) & collections.Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        # Also avoids 0/0 in the harmonic mean below.
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)

def process_batch_qa(batch):
    """Convert a QA batch into parallel lists of prompts and target answers.

    ``batch`` may be either of two shapes:
      * a column-oriented dict with list values under "question", "context"
        and "answers" (missing keys act as empty columns), or
      * an iterable of per-example dicts carrying those same three keys.

    Each prompt has the form "question: <q> context: <c>". The target is the
    first entry of that example's answers, or "" when the answers are empty.

    Returns:
        (prompts, targets): two lists of equal length.
    """
    if isinstance(batch, dict):
        rows = zip(
            batch.get("question", []),
            batch.get("context", []),
            batch.get("answers", []),
        )
    else:
        rows = ((ex['question'], ex['context'], ex['answers']) for ex in batch)

    prompts, targets = [], []
    for q, ctx, golds in rows:
        prompts.append(f"question: {q} context: {ctx}")
        targets.append(golds[0] if golds else "")
    return prompts, targets

def local_train_ic(model, trainloader, opt, steps=0, update=True, batch=None, device="cuda"):
    """Train an image classifier with cross-entropy loss.

    Args:
        model: callable whose output exposes ``.logits``. It is switched to
            train mode.
        trainloader: iterable of ``(inputs, targets)`` pairs.
        opt: optimizer. ``zero_grad()`` is called before every backward pass.
        steps: controls how many batches are used.
            * 0: one full pass over ``trainloader``.
            * N > 0: exactly N batches. The loader is re-iterated as often as
              needed, so it must be re-iterable (e.g. a DataLoader or list) for
              N to exceed one pass.
            * N < 0: no batches.
        update: if False, gradients are computed (and left in ``.grad``) but
            ``opt.step()`` is skipped.
        batch: if given, iterated instead of ``trainloader``. This is used when
            only per-head gradients are needed to pick which heads to fine-tune.
        device: where inputs/targets are moved. Defaults to "cuda" as before.

    Returns:
        The same ``model`` object.
    """
    model.train()

    if batch is not None:
        trainloader = batch

    def _restarting(loader):
        # itertools.cycle would cache every batch and replay the first pass.
        # A shuffling DataLoader would then never reshuffle, and all batches
        # would stay in memory. Re-iterate the loader on each pass instead.
        while True:
            produced = False
            for item in loader:
                produced = True
                yield item
            if not produced:  # empty or exhausted one-shot loader: stop instead of spinning
                return

    if steps == 0:
        batches = trainloader
    else:
        batches = itertools.islice(_restarting(trainloader), max(steps, 0))

    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)

        with torch.set_grad_enabled(True):
            outputs = model(inputs).logits
            loss = torch.nn.functional.cross_entropy(outputs, targets)

        opt.zero_grad()
        loss.backward()
        if update:
            opt.step()

    return model

def eval_ic(model, testloader, device="cuda"):
    """Return the top-1 accuracy of an image classifier and print it.

    Args:
        model: callable whose output exposes ``.logits``. The predicted class
            is ``logits.argmax(1)``. It is switched to eval mode.
        testloader: iterable of ``(inputs, targets)`` pairs.
        device: where inputs/targets are moved. Defaults to "cuda" as before.

    Returns:
        Python float: the fraction of correctly classified examples. Returns
        0.0 when ``testloader`` yields no examples (same convention as
        ``eval_qa``), instead of dividing by zero.
    """
    model.eval()
    total_err = 0
    total_counts = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs).logits
            # Accumulate as Python ints. The totals are exact and never mix
            # CUDA and CPU tensors.
            total_err += (targets != outputs.argmax(1)).sum().item()
            total_counts += len(inputs)

    accuracy = 1 - total_err / total_counts if total_counts > 0 else 0.0
    print(f"Eval accuracy: {accuracy}")
    return accuracy

def local_train_sc(model, trainloader, opt, tokenizer, steps=0, update=True, batch=None, is_reg=False, device="cuda"):
    """Fine-tune a sequence classifier or regressor on batches of raw text.

    Each batch is tokenized with ``tokenizer`` (padding and truncation on), fed
    to ``model(input_ids=..., attention_mask=...)``, and its ``.logits`` are
    scored with one of two losses:
      * cross-entropy by default, or
      * MSE when ``is_reg`` is True, in both the full-pass and fixed-step modes.

    Args:
        model: callable returning an object with ``.logits``. It is switched
            to train mode.
        trainloader: iterable of ``(texts, targets)`` pairs.
        opt: optimizer. ``zero_grad()`` is called before every backward pass.
        tokenizer: callable accepting ``(texts, return_tensors='pt',
            padding=True, truncation=True, max_length=...)`` and returning a
            mapping with 'input_ids' and 'attention_mask'.
        steps: controls how many batches are used.
            * 0: one full pass.
            * N > 0: exactly N batches, re-iterating the loader as needed.
              The loader must be re-iterable for N to exceed one pass.
            * N < 0: no batches.
        update: if False, gradients are computed but ``opt.step()`` is skipped.
        batch: if given, iterated instead of ``trainloader``. This is used for
            per-head gradient computation.
        is_reg: use MSE regression loss instead of cross-entropy.
        device: where tensors are moved. Defaults to "cuda" as before.

    Returns:
        The same ``model`` object.
    """
    model.train()

    if batch is not None:
        trainloader = batch

    def _restarting(loader):
        # Re-iterate the loader each pass instead of replaying the copy that
        # itertools.cycle caches. That keeps shuffling and avoids holding
        # every batch in memory.
        while True:
            produced = False
            for item in loader:
                produced = True
                yield item
            if not produced:  # empty or exhausted one-shot loader
                return

    def _loss(outputs, targets):
        if is_reg:
            # A single-output head typically yields (batch, 1) logits while
            # targets are presumably (batch,). mse_loss would broadcast those
            # to (batch, batch), so align the shapes when element counts agree.
            if outputs.numel() == targets.numel():
                outputs = outputs.reshape(targets.shape)
            return torch.nn.functional.mse_loss(outputs, targets.to(outputs.dtype))
        return torch.nn.functional.cross_entropy(outputs, targets)

    if steps == 0:
        # NOTE(review): a full pass truncates at 256 tokens but step mode (and
        # eval_sc) truncate at 128. Kept as-is; confirm this is intended.
        batches, max_length = trainloader, 256
    else:
        batches, max_length = itertools.islice(_restarting(trainloader), max(steps, 0)), 128

    for texts, targets in batches:
        encoding = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        targets = targets.to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask).logits
        loss = _loss(outputs, targets)

        opt.zero_grad()
        loss.backward()
        if update:
            opt.step()
    return model

def eval_sc(model, testloader, tokenizer, is_reg=False, device="cuda"):
    """Evaluate a sequence classifier or regressor and print the metric.

    Texts are tokenized with ``max_length=128`` truncation.

    For classification (``is_reg=False``), the prediction is
    ``logits.argmax(dim=1)`` and the return value is accuracy. It is 0.0 when
    the loader yields no examples, following ``eval_qa``'s convention.

    For regression (``is_reg=True``), the return value is the Pearson
    correlation between all flattened predictions and all flattened targets.
    It is computed once over the whole dataset. Averaging per-batch
    correlations is not the dataset correlation, and a final batch of size 1
    would make ``pearsonr`` raise. With fewer than two points the correlation
    is undefined and NaN is returned.

    Args:
        model: callable ``model(input_ids=..., attention_mask=...)`` returning
            ``.logits``.
        testloader: iterable of ``(texts, targets)`` pairs.
        tokenizer: HF-style tokenizer returning 'input_ids' and
            'attention_mask'.
        is_reg: report Pearson correlation instead of accuracy.
        device: where tensors are moved. Defaults to "cuda" as before.
    """
    model.eval()
    total_err = 0
    total_counts = 0
    reg_preds = []
    reg_golds = []
    with torch.no_grad():
        for texts, targets in testloader:
            encoding = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=128)
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)
            targets = targets.to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask).logits
            if is_reg:
                reg_preds.append(outputs.flatten().float().cpu())
                reg_golds.append(targets.flatten().float().cpu())
            else:
                predictions = outputs.argmax(dim=1)
                total_err += (predictions != targets).sum().item()
                total_counts += targets.size(0)

    if is_reg:
        preds = torch.cat(reg_preds).numpy() if reg_preds else np.empty(0)
        golds = torch.cat(reg_golds).numpy() if reg_golds else np.empty(0)
        pearson = float(pearsonr(preds, golds)[0]) if preds.size >= 2 else float("nan")
        print(f"Eval pearson correlations: {pearson}")
        return pearson

    accuracy = 1 - total_err / total_counts if total_counts > 0 else 0.0
    print(f"Eval accuracy: {accuracy}")
    return accuracy

def local_train_qa(model, trainloader, opt, tokenizer, steps=0, update=True, batch=None, device="cuda"):
    """Fine-tune a text-to-text QA model with teacher forcing.

    Each batch is turned into prompts and target answers by
    ``process_batch_qa``. Prompts are tokenized with max_length=256 and
    targets with max_length=128. The model is called with ``labels`` and its
    returned ``.loss`` is back-propagated.

    Args:
        model: callable ``model(input_ids=..., attention_mask=..., labels=...)``
            returning ``.loss``. It is switched to train mode.
        trainloader: iterable of QA batches in either format accepted by
            ``process_batch_qa``.
        opt: optimizer. ``zero_grad()`` is called before every backward pass.
        tokenizer: HF-style tokenizer exposing ``pad_token_id``.
        steps: controls how many batches are used.
            * 0: one full pass.
            * N > 0: exactly N batches, re-iterating the loader as needed.
              The loader must be re-iterable for N to exceed one pass.
            * N < 0: no batches.
        update: if False, gradients are computed but ``opt.step()`` is skipped.
        batch: if given, iterated instead of ``trainloader``. This is used for
            per-head gradient computation.
        device: where tensors are moved. Defaults to "cuda" as before.

    Returns:
        The same ``model`` object.
    """
    model.train()

    if batch is not None:
        trainloader = batch

    def _restarting(loader):
        # Re-iterate the loader each pass instead of replaying the copy that
        # itertools.cycle caches. That keeps shuffling and avoids holding
        # every batch in memory.
        while True:
            produced = False
            for item in loader:
                produced = True
                yield item
            if not produced:  # empty or exhausted one-shot loader
                return

    if steps == 0:
        batches = trainloader
    else:
        batches = itertools.islice(_restarting(trainloader), max(steps, 0))

    for qa_batch in batches:
        prompts, target_texts = process_batch_qa(qa_batch)
        encoding = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, max_length=256)
        target_encoding = tokenizer(target_texts, return_tensors='pt', padding=True, truncation=True, max_length=128)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        labels = target_encoding['input_ids'].to(device)
        # Replace padding token IDs with -100 so they are ignored in loss computation.
        labels[labels == tokenizer.pad_token_id] = -100

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        opt.zero_grad()
        loss.backward()
        if update:
            opt.step()
    return model

def eval_qa(model, testloader, tokenizer, device="cuda"):
    """Generate answers for every QA example and return the mean token F1.

    Prompts come from ``process_batch_qa`` and are tokenized with
    max_length=256. Answers are produced with
    ``model.generate(..., max_length=128)``, decoded with special tokens
    skipped, and scored against the reference using ``compute_f1``.

    NOTE(review): ``process_batch_qa`` keeps only the first gold answer. Many
    QA benchmarks take the max F1 over all gold answers, so this may
    under-report. Confirm this is intended.

    Args:
        model: object exposing ``eval()`` and
            ``generate(input_ids=..., attention_mask=..., max_length=...)``.
        testloader: iterable of QA batches in either format accepted by
            ``process_batch_qa``.
        tokenizer: HF-style tokenizer with ``batch_decode``.
        device: where tensors are moved. Defaults to "cuda" as before.

    Returns:
        Mean F1 over all examples, or 0.0 if there were none. The value is
        also printed.
    """
    model.eval()
    total_f1 = 0.0
    total_examples = 0
    with torch.no_grad():
        for batch in tqdm(testloader):
            prompts, references = process_batch_qa(batch)
            encoding = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, max_length=256)
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)

            generated_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=128)
            predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            for pred, ref in zip(predictions, references):
                total_f1 += compute_f1(pred, ref)
                total_examples += 1

    avg_f1 = total_f1 / total_examples if total_examples > 0 else 0.0
    print(f"Eval F1: {avg_f1}")
    return avg_f1

# Dataset-dispatching entry points: train() and eval() select the task-specific routine from the dataset name.

def train(model, trainloader, opt, dataset, tokenizer=None, steps=0, update=True, batch=None):
    """Dispatch to the training routine matching ``dataset`` and return the model.

    Routing by dataset name:
      * ``glue_*``: local_train_sc, with MSE regression when the name ends
        in ``_stsb``.
      * ``cifar10`` / ``svhn`` / ``cifar100``: local_train_ic.
      * ``20newsgroups``: local_train_sc (classification).
      * ``mrqa``: local_train_qa.

    ``steps``, ``update`` and ``batch`` are forwarded unchanged.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    shared = dict(steps=steps, update=update, batch=batch)

    if dataset.startswith("glue_"):
        return local_train_sc(model, trainloader, opt, tokenizer, is_reg=dataset.endswith("_stsb"), **shared)
    if dataset in ('cifar10', 'svhn', 'cifar100'):
        return local_train_ic(model, trainloader, opt, **shared)
    if dataset == '20newsgroups':
        return local_train_sc(model, trainloader, opt, tokenizer, **shared)
    if dataset == 'mrqa':
        return local_train_qa(model, trainloader, opt, tokenizer, **shared)
    raise NotImplementedError()

def eval(model, testloader, dataset, tokenizer=None):
    """Dispatch to the evaluation routine matching ``dataset`` and return its metric.

    Routing by dataset name:
      * ``glue_*``: eval_sc, which returns Pearson correlation when the name
        ends in ``_stsb`` and accuracy otherwise.
      * ``cifar10`` / ``svhn`` / ``cifar100``: eval_ic (accuracy).
      * ``20newsgroups``: eval_sc (accuracy).
      * ``mrqa``: eval_qa (mean token F1).

    Raises:
        NotImplementedError: for any other dataset name.
    """
    if dataset.startswith("glue_"):
        return eval_sc(model, testloader, tokenizer, is_reg=dataset.endswith("_stsb"))
    if dataset in ('cifar10', 'svhn', 'cifar100'):
        return eval_ic(model, testloader)
    if dataset == '20newsgroups':
        return eval_sc(model, testloader, tokenizer)
    if dataset == 'mrqa':
        return eval_qa(model, testloader, tokenizer)
    raise NotImplementedError()

def create_optimizer(model, opt, lr, momentum):
    """Build an optimizer over the model's trainable parameters.

    Only parameters with ``requires_grad=True`` are handed to the optimizer,
    so frozen weights are never stepped.

    Args:
        model: a ``torch.nn.Module``.
        opt: ``'sgd'`` or ``'adam'``.
        lr: learning rate.
        momentum: SGD momentum. Ignored for Adam.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.Adam`` instance.

    Raises:
        ValueError: if ``opt`` is not a supported name. Previously this fell
            through to an unbound local.
    """
    trainable = [parameter for parameter in model.parameters() if parameter.requires_grad]
    if opt == 'sgd':
        return torch.optim.SGD(trainable, lr=lr, momentum=momentum)
    if opt == 'adam':
        return torch.optim.Adam(trainable, lr=lr)
    raise ValueError(f"Unsupported optimizer {opt!r}; expected 'sgd' or 'adam'")