import os
import random
import torch
import torch.nn as nn
import numpy as np
import shutil
import datetime
import copy
from tools import *
from utils import MyConfig
import timm
from sklearn.metrics import confusion_matrix

import math
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from typing import Iterable, List, Dict, Tuple, Optional, Union
    
    
def evaluate(model, test_loader, num_classes, device):
    """Evaluate a classifier on ``test_loader`` and collect global and per-class metrics.

    Args:
        model: Module mapping a batch of inputs to class scores; ``torch.max(outputs, 1)``
            and the cross-entropy loss imply a ``(batch, num_classes)`` output.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        num_classes: Number of classes; labels index arrays of this length.
        device: Device the model and every batch are moved to.

    Returns:
        A tuple ``(cm, accuracy, macro_precision, macro_recall, macro_f1_score,
        avg_loss, precision, recall, f1_score, avg_class_loss)`` where ``cm`` is the
        confusion matrix (rows = true labels, columns = predictions), the
        ``macro_*`` values are unweighted means over classes, and the last four
        entries are per-class arrays. All values are rounded to 5 decimals.

        Undefined ratios follow the scikit-learn ``zero_division=0`` convention:
        precision/recall/F1 are 0.0 when their denominator is zero, so a class
        that is never predicted (or never present) no longer turns the macro
        averages into NaN. ``avg_class_loss`` stays NaN for a class with no
        samples, because no loss was observed for it.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    model.to(device)

    all_preds = []
    all_labels = []
    total_loss = 0.0
    total_count = 0

    class_losses = np.zeros(num_classes)
    class_counts = np.zeros(num_classes)

    # reduction='none' keeps one loss value per sample so it can be split by class.
    criterion = nn.CrossEntropyLoss(reduction='none')

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

            losses = criterion(outputs, labels)

            total_loss += losses.sum().item()
            total_count += labels.size(0)

            # Accumulate the summed loss and the sample count of each class in the batch.
            for label in labels.unique():
                class_mask = (labels == label)
                class_losses[label.item()] += losses[class_mask].sum().item()
                class_counts[label.item()] += class_mask.sum().item()

    if total_count == 0:
        raise ValueError("evaluate() received no samples from test_loader; metrics are undefined")

    def _safe_divide(numerator, denominator, fill):
        """Element-wise division that writes ``fill`` wherever the denominator is 0."""
        numerator = np.asarray(numerator, dtype=np.float64)
        denominator = np.asarray(denominator, dtype=np.float64)
        result = np.full(denominator.shape, fill, dtype=np.float64)
        np.divide(numerator, denominator, out=result, where=denominator != 0)
        return result

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    accuracy = round(np.mean(all_preds == all_labels), 5)

    cm = confusion_matrix(all_labels, all_preds, labels=np.arange(num_classes))
    TP = np.diagonal(cm)
    FP = cm.sum(axis=0) - TP  # predicted as the class but belonging to another
    FN = cm.sum(axis=1) - TP  # belonging to the class but predicted as another

    precision = np.round(_safe_divide(TP, TP + FP, 0.0), 5)
    recall = np.round(_safe_divide(TP, TP + FN, 0.0), 5)
    f1_score = np.round(_safe_divide(2 * precision * recall, precision + recall, 0.0), 5)

    macro_precision = round(np.mean(precision), 5)
    macro_recall = round(np.mean(recall), 5)
    macro_f1_score = round(np.mean(f1_score), 5)

    avg_loss = round(total_loss / total_count, 5)
    # NaN (not 0) for absent classes: a zero loss would wrongly look like a perfect fit.
    avg_class_loss = np.round(_safe_divide(class_losses, class_counts, np.nan), 5)

    return (cm, accuracy, macro_precision, macro_recall, macro_f1_score, avg_loss,
            precision, recall, f1_score, avg_class_loss)


def seed(seed=0):
    """Seed the Python, NumPy and PyTorch (CPU + current CUDA device) RNGs.

    Also disables cuDNN autotuning and requests deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same seeding order as before: Python -> NumPy -> torch CPU -> torch CUDA.
    for apply_seed in (random.seed, np.random.seed,
                       torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def _is_head_key(name: str) -> bool:
    """Return True if ``name`` contains 'head' or 'classifier', ignoring case."""
    lowered = name.lower()
    return any(marker in lowered for marker in ('head', 'classifier'))

def _backbone_name_order_from_model(model: torch.nn.Module):
    """List the state-dict keys of ``model`` that are not head entries.

    Keys keep the iteration order of ``model.state_dict()``; an entry is kept
    only if ``_is_head_key`` rejects its name and its value is a tensor.
    """
    backbone_keys = []
    for key, value in model.state_dict().items():
        if _is_head_key(key):
            continue
        if isinstance(value, torch.Tensor):
            backbone_keys.append(key)
    return backbone_keys

def _make_slices_from_shapes(tensors: List[torch.Tensor]):
    """Compute where each tensor would sit inside one concatenated flat vector.

    Returns:
        A list of ``(start, stop, shape)`` tuples, one per input tensor, where
        ``[start, stop)`` is the half-open range covering ``tensor.numel()``
        elements and ranges are laid out back to back in input order.
    """
    layout = []
    start = 0
    for tensor in tensors:
        stop = start + tensor.numel()
        layout.append((start, stop, tensor.shape))
        start = stop
    return layout

def _flatten_backbone_by_order(model: torch.nn.Module, name_order: List[str]):
    """Concatenate the named state-dict entries of ``model`` into one 1-D tensor.

    Entries are flattened individually and joined in the order given by
    ``name_order``.
    """
    state = model.state_dict()
    pieces = []
    for name in name_order:
        pieces.append(state[name].reshape(-1))
    return torch.cat(pieces, dim=0)

def _assign_backbone_from_flat(model: torch.nn.Module,
                               name_order: List[str],
                               flat: torch.Tensor):
    """Write a flat vector back into the named state-dict entries of ``model``.

    ``flat`` is cut into consecutive chunks sized after the entries listed in
    ``name_order``; each chunk is reshaped to its entry's shape and copied in
    place, after which the state dict is reloaded non-strictly.
    """
    state = model.state_dict()
    layout = _make_slices_from_shapes([state[name] for name in name_order])
    with torch.no_grad():
        for name, (start, stop, shape) in zip(name_order, layout):
            chunk = flat[start:stop]
            state[name].copy_(chunk.view(shape))
    model.load_state_dict(state, strict=False)

def _taskvector_to_backbone_flat(tv, name_order: List[str], device, dtype,
                                 ref_state: Optional[Dict[str, torch.Tensor]] = None):
    """Flatten the entries of a task vector into one 1-D tensor following ``name_order``.

    Args:
        tv: Object exposing a ``vector`` mapping from parameter name to tensor.
        name_order: Names to concatenate, in order.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.
        ref_state: Optional mapping from name to a reference tensor (for example a
            model ``state_dict()``). When a name is absent from ``tv.vector`` it is
            zero-filled with ``ref_state[name].numel()`` elements so the result
            stays aligned with ``name_order``.

    Returns:
        A 1-D tensor on ``device`` with ``dtype``; empty when ``name_order`` is empty.

    Raises:
        KeyError: If a name is missing from ``tv.vector`` and its size cannot be
            taken from ``ref_state``. The previous behaviour padded with the size
            of the *preceding* entry, which silently shifted every later entry.
    """
    if not name_order:
        return torch.zeros(0, device=device, dtype=dtype)

    pieces = []
    for n in name_order:
        if n in tv.vector:
            t = tv.vector[n].to(device=device, dtype=dtype)
            pieces.append(t.reshape(-1))
        elif ref_state is not None and n in ref_state:
            # Missing entry contributes no change: zeros of the correct length.
            pieces.append(torch.zeros(ref_state[n].numel(), device=device, dtype=dtype))
        else:
            raise KeyError(
                f"task vector has no entry for '{n}' and no ref_state was given to "
                "determine its size; cannot build an aligned flat vector"
            )
    return torch.cat(pieces, dim=0)

@torch.no_grad()
def gram_schmidt_orthogonal_then_normalize(
    vectors: List[torch.Tensor],
    target_index: int = 0,
    normalize_target: bool = False,
    eps: float = 1e-12
):
    """Orthogonalize 1-D vectors with the target first, then normalize them.

    The vector at ``target_index`` is processed first and therefore left
    untouched by the orthogonalization; every other vector has its projections
    onto all previously processed vectors removed, in index order.

    Args:
        vectors: 1-D tensors (``torch.dot`` is applied to pairs of them).
        target_index: Index of the vector that anchors the orthogonalization.
        normalize_target: If False the target is returned unscaled in ``hat``.
        eps: Threshold below which a squared norm / norm is treated as zero, and
            the stabilizer added to denominators.

    Returns:
        ``(tilde, hat)``: two lists in the original input order. ``tilde`` holds
        the orthogonalized vectors; ``hat`` holds them scaled to unit norm
        (all-zero if the norm is below ``eps``), except for an un-normalized target.
    """
    count = len(vectors)
    order = [target_index]
    order.extend(i for i in range(count) if i != target_index)

    # Pass 1: remove from each vector its components along the earlier ones.
    orthogonal: List[torch.Tensor] = []
    for idx in order:
        residual = vectors[idx].clone()
        for basis in orthogonal:
            basis_sq = torch.dot(basis, basis)
            if float(basis_sq) > eps:
                coeff = torch.dot(residual, basis) / (basis_sq + eps)
                residual = residual - coeff * basis
        orthogonal.append(residual)

    # Pass 2: normalize and scatter results back to the caller's ordering.
    tilde = [None] * count
    hat = [None] * count
    for pos, (idx, residual) in enumerate(zip(order, orthogonal)):
        keep_unscaled = (pos == 0) and (not normalize_target)
        if keep_unscaled:
            unit = residual
        else:
            length = torch.norm(residual)
            if float(length) < eps:
                unit = torch.zeros_like(residual)
            else:
                unit = residual / (length + eps)
        tilde[idx] = residual
        hat[idx] = unit

    return tilde, hat


def _entropy_from_logits(logits: torch.Tensor, temperature: float = 1.0):
    """Mean Shannon entropy (in nats) of the softmax over dim 1 of ``logits``.

    Args:
        logits: Tensor whose dim 1 indexes classes.
        temperature: Divisor applied to the logits before the softmax; skipped
            when exactly 1.0.

    Returns:
        Scalar tensor: per-row entropy averaged over the remaining elements.
    """
    scaled = logits if temperature == 1.0 else logits / temperature
    probs = torch.softmax(scaled, dim=1)
    # Clamp before the log so zero probabilities contribute 0 instead of NaN.
    log_probs = probs.clamp_min(1e-12).log()
    per_row = -(probs * log_probs).sum(dim=1)
    return per_row.mean()

@torch.no_grad()
def _clone_like(model: torch.nn.Module):
    """Return a deep copy of ``model`` whose parameters all require gradients."""
    from copy import deepcopy

    replica = deepcopy(model)
    for param in replica.parameters():
        param.requires_grad_(True)
    return replica

def taskwise_adamerging_optimize_alphas_with_fairness(
    theta_pre_template: torch.nn.Module,
    task_vectors: List["TaskVector"],      
    alpha_init: List[float],               
    work_models: List[torch.nn.Module],
    unlabeled_loaders: List[Iterable],
    device: Union[str, torch.device],
    steps: int = 100,
    lr: float = 1e-2,
    eval_batches_per_task: int = 32,
    temperature: float = 1.0,
    alpha_min: float = 1e-6,
    alpha_max: float = 1.0,
    early_stop_tol: float = 1e-5,
    early_stop_patience: int = 10,
    use_adam: bool = True,
    grad_clip_alpha: float = 0.0,
    deltaG: float = 0.0,
    deltaLambda: float = 0.0,
    lambda_fair: float = 0.0,
    target_index: int = 0,
    normalize_tau: bool = True
) -> Tuple[List[float], Dict[str, List], List[torch.Tensor], List[str], torch.Tensor]:
    """Optimize per-task merge coefficients by entropy minimization plus a fairness penalty.

    The merged backbone is ``theta_pre + sum_i alphas[i] * tau_i`` where ``tau_i``
    are the flattened task vectors (optionally Gram-Schmidt processed). Each step:

    1. writes the current merged backbone into every model in ``work_models``;
    2. accumulates, over up to ``eval_batches_per_task`` batches per task, the
       gradient of the normalized prediction entropy w.r.t. the backbone;
    3. projects that gradient onto each ``tau_i`` to obtain d(entropy)/d(alpha_i);
    4. adds ``lambda_fair`` times the gradient of the fairness term
       ``J = delta * deltaG + 0.5 * delta**2 * deltaLambda`` with
       ``delta = (1 - alpha_target) * ||tau_target|| + sum(other alphas)``;
    5. updates ``alphas`` with bias-corrected Adam (or plain gradient descent) and
       clamps them to ``[alpha_min, alpha_max]``.

    Args:
        theta_pre_template: Model whose non-head state-dict entries define the
            pre-trained backbone and the flattening order. It is moved to ``device``.
        task_vectors: Objects exposing a ``vector`` mapping (name -> tensor).
        alpha_init: Initial coefficient per task vector.
        work_models: One model per task; their backbones are overwritten in place
            and their ``requires_grad`` flags are changed (head off, rest on).
        unlabeled_loaders: One iterable of batches per task. A batch is either a
            list/tuple whose first item is the input, or a mapping with key ``'x'``.
        device: Device used for all flat vectors and inputs.
        steps: Maximum number of optimization steps.
        lr: Step size for the alpha update.
        eval_batches_per_task: Maximum batches consumed per task and step.
        temperature: Softmax temperature used in the entropy.
        alpha_min: Lower clamp for the coefficients.
        alpha_max: Upper clamp for the coefficients.
        early_stop_tol: Minimum objective improvement that resets the stall counter.
        early_stop_patience: Number of consecutive non-improving steps before stopping.
        use_adam: Use the Adam update if True, plain gradient descent otherwise.
        grad_clip_alpha: If > 0, clamp each entropy gradient component to
            ``[-grad_clip_alpha, grad_clip_alpha]``.
        deltaG: Linear coefficient of the fairness term.
        deltaLambda: Quadratic coefficient of the fairness term.
        lambda_fair: Weight of the fairness term in the objective and gradient.
        target_index: Index of the target task (first in Gram-Schmidt, special in delta).
        normalize_tau: If True, task vectors are orthogonalized and all but the
            target are normalized before use.

    Returns:
        ``(alphas, history, tau_flats, name_order, theta_pre_flat)``: the final
        coefficients as floats, the per-step log, the task-vector flats actually
        used, the backbone key order, and the flattened pre-trained backbone.
    """
    device = torch.device(device)
    theta_pre_template = theta_pre_template.to(device)
    name_order = _backbone_name_order_from_model(theta_pre_template)
    back_dtype = next(theta_pre_template.parameters()).dtype
    # torch.cat inside the flatten helper produces a fresh tensor, so later in-place
    # writes to the models do not alter this snapshot of the pre-trained backbone.
    theta_pre_flat = _flatten_backbone_by_order(theta_pre_template, name_order).detach().to(device)
    raw_tau_flats = [
        _taskvector_to_backbone_flat(tv, name_order, device, back_dtype).detach()
        for tv in task_vectors
    ]
    if normalize_tau:
        # Keep only the normalized ("hat") output; the target itself stays unscaled.
        _, tau_flats = gram_schmidt_orthogonal_then_normalize(
            raw_tau_flats,
            target_index=target_index,
            normalize_target=False,  
            eps=1e-12
        )
    else:
        tau_flats = raw_tau_flats

    alphas = torch.tensor(alpha_init, dtype=torch.float32, device=device).clamp_(min=alpha_min, max=alpha_max)

    # Adam first/second moment estimates and hyper-parameters.
    m = torch.zeros_like(alphas); v = torch.zeros_like(alphas)
    beta1, beta2, eps = 0.9, 0.999, 1e-8

    history: Dict[str, List] = {
        'entropy': [], 'per_task_entropy': [], 'alphas': [], 'dalpha': [],
        'fair': [], 'delta': []
    }
    
    # log(C) per task, where C is the logit width seen on the first batch; used to
    # normalize entropies. Falls back to 1.0 when a loader yields no batch.
    logC_per_task: List[float] = []
    for wm, ul in zip(work_models, unlabeled_loaders):
        C_t = None
        for batch in ul:
            x = batch[0] if isinstance(batch, (list, tuple)) else batch['x']
            with torch.no_grad():
                logits = wm(x.to(device))
                C_t = logits.shape[1]
            break
        logC_per_task.append(math.log(float(C_t) + 1e-12) if C_t is not None else 1.0)

    best_val, stall = float('inf'), 0

    for s in range(steps):
        # Build the merged backbone for the current coefficients.
        theta_cur = theta_pre_flat.clone()
        for a, tflat in zip(alphas, tau_flats):
            if float(a) != 0.0:
                theta_cur.add_(tflat, alpha=float(a))

        # Every work model shares the same merged backbone; heads are left as they are.
        for wm in work_models:
            _assign_backbone_from_flat(wm, name_order, theta_cur)

        grad_theta = torch.zeros_like(theta_cur)
        ent_sum, cnt = 0.0, 0
        per_task_ents_step: List[float] = []

        for t_idx, (wm, ul) in enumerate(zip(work_models, unlabeled_loaders)):
            wm.zero_grad(set_to_none=True)
            # Differentiate w.r.t. the backbone only; these flags persist after return.
            for n, p in wm.named_parameters():
                p.requires_grad_(not _is_head_key(n))

            used, ent_task = 0, 0.0
            for b_idx, batch in enumerate(ul):
                x = batch[0] if isinstance(batch, (list, tuple)) else batch['x']
                x = x.to(device)
                wm.zero_grad(set_to_none=True)
                logits = wm(x)
                ent = _entropy_from_logits(logits, temperature=temperature)
                # Divide by log(C) (when positive) to normalize the entropy across tasks.
                norm_ent = ent / (logC_per_task[t_idx] if logC_per_task[t_idx] > 0 else 1.0)

                back_params = [p for n, p in wm.named_parameters() if not _is_head_key(n)]
                g_list = torch.autograd.grad(norm_ent, tuple(back_params),
                                             retain_graph=False, create_graph=False, allow_unused=True)
                # Parameters unused in the forward pass get a zero gradient.
                # NOTE(review): g_flat follows named_parameters() while grad_theta follows
                # name_order (from state_dict()); these line up only if the backbone has
                # no buffers in its state dict — confirm for non-ViT models.
                g_flat = torch.cat([(torch.zeros_like(p) if g is None else g).reshape(-1)
                                    for g, p in zip(g_list, back_params)])
                grad_theta.add_(g_flat)

                v_ent = float(norm_ent.detach().cpu())
                ent_task += v_ent; ent_sum += v_ent
                used += 1; cnt += 1
                if used >= eval_batches_per_task:
                    break

            per_task_ents_step.append(ent_task / max(1, used))

        # Average entropy and gradient over all batches processed this step.
        ent_val = (ent_sum / cnt) if cnt > 0 else 0.0
        if cnt > 0:
            grad_theta.div_(cnt)

        # Chain rule: theta is linear in alpha, so d(ent)/d(alpha_i) = <grad_theta, tau_i>.
        dent_dalpha = torch.tensor(
            [torch.dot(grad_theta, tflat) for tflat in tau_flats],
            device=device, dtype=torch.float32
        )
        if grad_clip_alpha and grad_clip_alpha > 0:
            dent_dalpha = torch.clamp(dent_dalpha, min=-grad_clip_alpha, max=grad_clip_alpha)
            
        # Fairness term: delta = (1 - alpha_target) * ||tau_target|| + sum of the other alphas.
        alpha0 = alphas[target_index]
        sum_aux = alphas.sum() - alpha0
        norm_tau0 = torch.norm(tau_flats[target_index]).clamp_min(1e-12)
        Δ = (1.0 - alpha0) * norm_tau0 + sum_aux
        J_fair = float(Δ.item() * deltaG + 0.5 * (Δ.item() ** 2) * deltaLambda)

        # dJ/d(delta) is shared; d(delta)/d(alpha_i) is 1 for auxiliaries and
        # -||tau_target|| for the target.
        common = (deltaG + Δ * deltaLambda)
        dJ = torch.full_like(alphas, common)
        dJ[target_index] = -common * norm_tau0 

        grad_alpha = dent_dalpha + lambda_fair * dJ

        if use_adam:
            # Adam with bias correction; s is zero-based, hence the (s + 1) exponents.
            m = beta1 * m + (1 - beta1) * grad_alpha
            v = beta2 * v + (1 - beta2) * (grad_alpha * grad_alpha)
            m_hat = m / (1 - beta1 ** (s + 1))
            v_hat = v / (1 - beta2 ** (s + 1))
            alphas = alphas - lr * m_hat / (torch.sqrt(v_hat) + eps)
        else:
            alphas = alphas - lr * grad_alpha

        alphas = alphas.clamp_(min=alpha_min, max=alpha_max)

        # The logged entropy/fairness were measured before this step's update, while
        # the logged alphas are the values after it.
        cur_obj = ent_val + lambda_fair * J_fair
        history['entropy'].append(ent_val)
        history['per_task_entropy'].append(per_task_ents_step)
        history['alphas'].append([float(x) for x in alphas.tolist()])
        history['dalpha'].append([float(x) for x in grad_alpha.tolist()])
        history['fair'].append(J_fair)
        history['delta'].append(float(Δ.item()))
        print(f'epoch: {s}——entropy: {ent_val}——alphas: {[float(x) for x in alphas.tolist()]}——fair: {J_fair}——delta: {float(Δ.item())}')

        # Early stopping: count steps that fail to improve the best objective by
        # more than early_stop_tol.
        if cur_obj + 1e-12 < best_val - early_stop_tol:
            best_val = cur_obj
            stall = 0
        else:
            stall += 1
            if stall >= early_stop_patience:
                break

    # The returned alphas are those of the last executed step, not of the best objective.
    return [float(a) for a in alphas.tolist()], history, tau_flats, name_order, theta_pre_flat




def main():
    """Run the FairMerging experiment end to end.

    Steps: load the experiment and per-task configs, create a timestamped result
    directory (copying the config file into it), build the test dataloaders,
    pre-trained models, fine-tuned checkpoints and task vectors, optimize the
    merge coefficients, assemble one merged model per task (shared merged
    backbone plus that task's fine-tuned head), evaluate each of them and save
    coefficients, optimization history and evaluation results under the result
    directory.
    """
    # Same text as str(datetime.now())[:19] with ':', '-' and ' ' replaced by '_'.
    now = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

    config_dict = {'CIFAR10': "config/CIFAR10/",
                   'CINIC10': "config/CINIC10/",
                   'GTSRB': "config/GTSRB/",
                   'UTK-Face': "config/UTK-Face/",
                   'SVHN': "config/SVHN/",
                   'FER-2013': "config/FER-2013/"}

    config = MyConfig.MyConfig(path="config/FairMerging/")
    seed(config.general.seed)
    tasks = config.general.tasks_list
    tasks_str = '_'.join(tasks)
    path = config.path.result_path + '/six/FairMerging/' + tasks_str + '_' + now
    os.makedirs(path)
    shutil.copy(config.path.config_path, path + "/config.yaml")

    device = torch.device(config.general.device)

    configs_list = [MyConfig.MyConfig(path=config_dict[task]) for task in tasks]
    dataloaders_list = []
    for task, task_config in zip(tasks, configs_list):
        dataloader, _ = dataset(task, task_config)
        dataloaders_list.append(dataloader['test'])

    pretrained_models_list = [
        timm.create_model(config.general.model, pretrained=True, num_classes=num_class).to(device)
        for num_class in config.general.num_classes_list
    ]
    # Snapshot the pre-trained weights now: the models themselves are mutated below.
    pretrained_state_dict_list = [
        {k: v.clone() for k, v in m.state_dict().items()} for m in pretrained_models_list
    ]
    # map_location keeps the checkpoints on the configured device, so they can be
    # combined with the pre-trained weights even if they were saved from another device.
    # NOTE(review): the checkpoint path hard-codes 'vit_base_patch32_224' regardless of
    # config.general.model — confirm this is intended.
    finetuned_state_dict_list = [
        torch.load(f"./saved_models/vit_base_patch32_224_{task}/base_model.pth", map_location=device)
        for task in tasks
    ]
    task_vectors_list = [
        TaskVector(pre_sd, ft_sd)
        for pre_sd, ft_sd in zip(pretrained_state_dict_list, finetuned_state_dict_list)
    ]

    head_keys = ("head.weight", "head.bias", "classifier.weight", "classifier.bias")

    # The work models ARE the pre-trained model objects, with each task's fine-tuned
    # head copied in; the optimizer overwrites their backbones in place.
    work_models = []
    with torch.no_grad():
        for model, ft_sd in zip(pretrained_models_list, finetuned_state_dict_list):
            model_sd = model.state_dict()
            for head_key in head_keys:
                if head_key in ft_sd and head_key in model_sd:
                    model_sd[head_key].copy_(ft_sd[head_key])
            work_models.append(model)

    T = len(task_vectors_list)
    alpha_init = [1.0 / T] * T

    alphas, hist, tau_flats, name_order, theta_pre_flat = taskwise_adamerging_optimize_alphas_with_fairness(
        theta_pre_template = pretrained_models_list[0],
        task_vectors       = task_vectors_list,
        alpha_init         = alpha_init,
        work_models        = work_models,
        unlabeled_loaders  = dataloaders_list,
        device             = device,
        steps              = config.learning.steps,
        lr                 = config.learning.lr,
        eval_batches_per_task = config.learning.eval_batches_per_task,
        temperature        = config.learning.temperature,
        use_adam           = config.learning.use_adam,
        early_stop_tol     = config.learning.early_stop_tol,
        early_stop_patience= config.learning.early_stop_patience,
        deltaG             = config.learning.deltaG,
        deltaLambda        = config.learning.deltaLambda,
        lambda_fair        = config.learning.lambda_fair,
        target_index       = config.learning.target_index,
        normalize_tau      = config.learning.normalize_tau
    )
    print("Final alphas:", alphas)
    print("Last entropy:", hist["entropy"][-1])
    print("entropy hist:", hist["entropy"])
    print("fair:", hist['fair'])
    print("delta:", hist['delta'])
    np.save(path + "/alphas.npy", alphas)
    np.save(path + "/hist.npy", hist)

    # Final merged backbone: pre-trained weights plus the scaled, weighted task vectors.
    theta_merged_flat = theta_pre_flat.clone()
    scaling = float(getattr(config.general, "scaling_coef", 1.0))
    for a, tflat in zip(alphas, tau_flats):
        if float(a) != 0.0:
            theta_merged_flat.add_(tflat, alpha=float(a) * scaling)

    merged_models = []
    for m, sd, ft_sd in zip(pretrained_models_list, pretrained_state_dict_list, finetuned_state_dict_list):
        # Reset to the pre-trained snapshot, then install merged backbone and task head.
        m.load_state_dict(sd, strict=True)
        _assign_backbone_from_flat(m, name_order, theta_merged_flat)

        with torch.no_grad():
            merged_sd = m.state_dict()
            for head_key in head_keys:
                if head_key in ft_sd and head_key in merged_sd:
                    merged_sd[head_key].copy_(ft_sd[head_key])

        merged_models.append(m)

    print('Evaluating the results.........................................................................................................................')
    all_evaluation_results = {}
    for merged_model, dataloader, num_classes, task in zip(
            merged_models, dataloaders_list, config.general.num_classes_list, tasks):
        (cm, accuracy, macro_precision, macro_recall, macro_f1_score, avg_loss,
         precision, recall, f1_score, avg_class_loss) = evaluate(merged_model, dataloader, num_classes, device)
        all_evaluation_results[task] = {
            "accuracy": accuracy,
            "avg_loss": avg_loss,
            "per_class_recall": recall,
            "avg_class_loss": avg_class_loss
        }
        # Saved after every task so partial results survive a failure on a later task.
        np.save(path + "/evaluation_results.npy", all_evaluation_results)
    
    
if __name__ == "__main__":
    # Force the "spawn" start method for child processes before the experiment runs.
    from multiprocessing import set_start_method

    set_start_method("spawn", force=True)
    main()
    