

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import random
from torch.utils.tensorboard import SummaryWriter
import argparse
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from datetime import datetime
from tasks.task_generator import configure_dataset
from utils.utils import *
import wandb
import os 
from agent import *
from collections import deque
import hydra
from omegaconf import DictConfig
import matplotlib.pyplot as plt
from copy import deepcopy
from contextlib import contextmanager
import random, torch, numpy as np


def save_artifact(matrix, name, wandb_enable, num_tasks, length):
    """Persist a task-by-task result matrix as ``./bwa_data/{name}.npy``.

    Every cell of the ``num_tasks`` x ``num_tasks`` grid that is still empty
    is overwritten *in place* with ``length`` zeros before the matrix is
    saved. When ``wandb_enable`` is truthy the saved file is additionally
    logged as a W&B artifact of type "dataset".
    """
    out_path = f"./bwa_data/{name}.npy"

    # Pad the cells that never received results with zero-filled lists.
    for row_idx in range(num_tasks):
        row = matrix[row_idx]
        for col_idx in range(num_tasks):
            if len(row[col_idx]) == 0:
                row[col_idx] = [0] * length

    os.makedirs('./bwa_data/', exist_ok=True)
    np.save(out_path, matrix)

    if not wandb_enable:
        return

    artifact = wandb.Artifact(name, type="dataset")
    artifact.add_file(out_path)
    wandb.log_artifact(artifact)


def _count_correct(model, device, loader):
    """Return ``(num_correct, num_seen)`` for ``model.predict`` over ``loader``."""
    correct = 0
    seen = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        preds = model.predict(data).argmax(dim=1)
        correct += preds.eq(target).sum().item()
        seen += data.size(0)
    return correct, seen


def evaluate_one_task(model, device, aux_train_loader, aux_test_loader):
    """Measure the accuracy of ``model`` on a train loader and a test loader.

    The model is switched to eval mode for the measurement and is always put
    back into train mode afterwards (also when evaluation raises).

    Args:
        model: object exposing ``eval()``, ``train()`` and ``predict(data)``;
            ``predict`` must return per-class scores (argmax over dim 1).
        device: device the batches are moved to before prediction.
        aux_train_loader: iterable of ``(data, target)`` batches.
        aux_test_loader: iterable of ``(data, target)`` batches.

    Returns:
        ``(train_accuracy, test_accuracy)`` as floats in ``[0, 1]``. A loader
        that yields no samples contributes ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    model.eval()
    try:
        with torch.no_grad():
            train_correct, train_total = _count_correct(model, device, aux_train_loader)
            test_correct, test_total = _count_correct(model, device, aux_test_loader)
    finally:
        # Callers keep training with the same object, so restore train mode.
        model.train()

    train_acc = train_correct / train_total if train_total else 0.0
    test_acc = test_correct / test_total if test_total else 0.0
    return train_acc, test_acc



def BWA_one_task(model, cfg, agent_config, arch_config, task_config, device, train_loader, test_loader, epochs, eval_interval):
    """
    Adapt ``model`` on one task for ``epochs`` epochs and record accuracy curves.

    After every ``eval_interval``-th training step (including the very first
    one) a fresh agent is created with ``create_model``, loaded with the
    current weights of ``model`` and evaluated on ``train_loader`` and
    ``test_loader`` via ``evaluate_one_task``.

    Returns:
        ``(train_result, test_result)``: two lists holding one train / test
        accuracy per evaluation, in chronological order. Both are empty when
        no batch was processed.
    """
    # Nothing inside the loop moves the model or changes its mode (evaluation
    # runs on a separate snapshot), so this only has to happen once.
    model.to(device)
    model.train()
    batch_counter = 0
    train_result = []
    test_result = []

    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            model.step(data, target)

            if batch_counter % eval_interval == 0:
                # Evaluate on a copy so the model being adapted is untouched.
                eval_model = create_model(cfg, agent_config, arch_config, task_config, device)
                eval_model.load_state_dict(model.state_dict())
                if cfg.agent.agent_type == "NeuroSyncAgent":
                    eval_model.load_consolidated_weights(model.model._ema_params)
                eval_model.to(device)
                eval_model.eval()
                train_acc, test_acc = evaluate_one_task(eval_model, device=device, aux_train_loader=train_loader, aux_test_loader=test_loader)
                train_result.append(train_acc)
                test_result.append(test_acc)

            batch_counter += 1

    return train_result, test_result

def _snapshot_for_eval(model, cfg, agent_config, arch_config, task_config, device):
    """Build a fresh agent carrying ``model``'s current weights, in eval mode."""
    snapshot = create_model(cfg, agent_config, arch_config, task_config, device)
    snapshot.load_state_dict(model.state_dict())
    if cfg.agent.agent_type == "NeuroSyncAgent":
        snapshot.load_consolidated_weights(model.model._ema_params)
    snapshot.to(device)
    snapshot.eval()
    return snapshot


def train_one_task(task_id, model, cfg, agent_config, arch_config, task_config, device, train_loader, test_loader, epochs, reoccured, global_counter, eval_interval, wandb_enable, aux_train_loader, aux_test_loader):
    """
    Train ``model`` on one task for ``epochs`` epochs with periodic evaluation.

    After every ``eval_interval``-th training step a snapshot of the model is
    evaluated on ``aux_train_loader`` / ``aux_test_loader``; the accuracies
    are appended to the returned curves, optionally logged to W&B at step
    ``global_counter``, and ``global_counter`` is incremented. For a
    re-occurring task an extra "few-shot" evaluation is taken at the end of
    the second epoch (``ep == 1``). ``test_loader`` is accepted for interface
    compatibility but is not read; all evaluation uses the ``aux_*`` loaders.

    Returns:
        ``(few_shot_train_acc, few_shot_test_acc, final_train_acc,
        final_test_acc, train_result, test_result, metrics, global_counter)``
        where ``metrics`` is the dict returned by the last ``model.step`` call
        (an empty dict if no batch was processed) and the final accuracies are
        measured on ``model`` itself after training.
    """
    model.to(device)
    model.train()
    batch_counter = 0
    few_shot_train_acc = 0
    few_shot_test_acc = 0

    train_result = []
    test_result = []
    # Pre-bind so the return statement is valid even when no batch runs.
    metrics = {}

    with tqdm(range(epochs), desc=f'training the model task {task_id}', disable=False) as pbar:
        # Iterate the bar itself; iterating a separate range() would leave
        # the progress display stuck at zero.
        for ep in pbar:
            model.to(device)
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                _, metrics = model.step(data, target)

                if batch_counter % eval_interval == 0:
                    eval_model = _snapshot_for_eval(model, cfg, agent_config, arch_config, task_config, device)
                    train_acc, test_acc = evaluate_one_task(eval_model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)
                    if wandb_enable:
                        log_info = {f'train_acc_{task_config.benchmark}': train_acc,
                                    f'test_acc_{task_config.benchmark}': test_acc}
                        wandb.log(log_info, step=global_counter)

                    train_result.append(train_acc)
                    test_result.append(test_acc)

                    global_counter += 1

                batch_counter += 1

            if reoccured and ep == 1:
                eval_model = _snapshot_for_eval(model, cfg, agent_config, arch_config, task_config, device)
                few_shot_train_acc, few_shot_test_acc = evaluate_one_task(eval_model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)

    final_train_acc, final_test_acc = evaluate_one_task(model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)

    return few_shot_train_acc, few_shot_test_acc, final_train_acc, final_test_acc, train_result, test_result, metrics, global_counter


def create_model(cfg, agent_config, arch_config, task_config, device):
    """Instantiate the agent class selected by ``cfg.agent.agent_type``.

    Every agent receives ``agent_config``, ``arch_config`` and
    ``task_config``; some additionally receive ``device``. For
    "NeuroSyncAgent", ``agent_config.ema_decay`` is (re)computed from the
    task configuration and printed before the agent is built.

    Raises:
        ValueError: if the agent type is not one of the known names.
    """
    agent_type = cfg.agent.agent_type
    shared_kwargs = dict(agent_config=agent_config, arch_config=arch_config, task_config=task_config)

    if agent_type == "BaseAgent":
        return BaseAgent(**shared_kwargs)

    if agent_type == "L2Agent":
        return L2Agent(**shared_kwargs)

    if agent_type == "ReDoAgent":
        return ReDoAgent(**shared_kwargs)

    if agent_type == "L2InitAgent":
        return L2InitAgent(device=device, **shared_kwargs)

    if agent_type == "LayerNormAgent":
        return LayerNormAgent(**shared_kwargs)

    if agent_type == "NeuroSyncAgent":
        # Choose the per-step decay so that it compounds to `ema_target`
        # over the number of forward passes made on one task.
        steps_per_task = task_config.epochs * (task_config.limit / task_config.batch_size)
        agent_config.ema_decay = float(np.power(task_config.ema_target, 1/steps_per_task))
        print('ema_decay:', agent_config.ema_decay)
        return NeuroSyncAgent(device=device, **shared_kwargs)

    if agent_type == "DeepFourierAgent":
        return DeepFourierAgent(**shared_kwargs)

    if agent_type == "CReLUAgent":
        return CReLUAgent(**shared_kwargs)

    if agent_type == "PReLUAgent":
        return PReLUAgent(**shared_kwargs)

    if agent_type == "CBPAgent":
        return CBPAgent(device=device, **shared_kwargs)

    if agent_type == "EWCAgent":
        return EWCAgent(device=device, **shared_kwargs)

    if agent_type == "L2InitPlusEWCAgent":
        return L2InitPlusEWCAgent(device=device, **shared_kwargs)

    if agent_type == "ShrinkAndPerturbAgent":
        return ShrinkAndPerturbAgent(device=device, **shared_kwargs)

    raise ValueError(f"Unknown agent: {cfg.agent}")

def _remap_checkpoint_keys(old_state, arch_name):
    """Translate checkpoint keys to this project's ``model.*`` parameter names.

    Returns the remapped dict, or ``None`` when ``arch_name`` has no mapping.
    """
    remapped = {}

    if arch_name == 'MLP':
        # Entries are consumed in order; each `bias` closes one fc layer.
        layer_idx = 1
        for name, tensor in old_state.items():
            suffix = name.split('.')[-1]
            remapped[f'model.fc{layer_idx}.{suffix}'] = tensor
            if suffix == 'bias':
                layer_idx += 1
        return remapped

    if arch_name == 'CNN':
        # `convs*` entries are numbered in order (a `bias` closes a layer);
        # every `fc*` entry maps onto the single `fc1` head.
        conv_idx = 1
        for name, tensor in old_state.items():
            suffix = name.split('.')[-1]
            if name.startswith('convs'):
                remapped[f'model.conv{conv_idx}.{suffix}'] = tensor
                if suffix == 'bias':
                    conv_idx += 1
            elif name.startswith('fc'):
                remapped[f'model.fc1.{suffix}'] = tensor
        print(remapped.keys())
        return remapped

    return None


def load_weights(cfg, agent_config, arch_config, task_config, device, checkpoint_dir="/localhome/srr8/projects/17/maml_model"):
    """Create an agent and initialise it from a meta-trained checkpoint.

    The checkpoint file name is
    ``{arch_name}_{agent_type}_{benchmark}_{meta_train_epoch}_init.pth``
    inside ``checkpoint_dir``. Its tensors are loaded onto ``device`` (so a
    CPU-only host works), their keys are remapped to the agent's parameter
    names, and any entry of the agent's state dict that the checkpoint does
    not provide keeps its freshly initialised value.

    Args:
        cfg, agent_config, arch_config, task_config: experiment configuration,
            forwarded to ``create_model``.
        device: device the checkpoint is mapped to and the model is moved to.
        checkpoint_dir: directory containing the checkpoint; defaults to the
            location that used to be hard-coded.

    Returns:
        The created agent. For architectures other than 'MLP' and 'CNN' no
        key mapping exists yet, so the agent is returned with its initial
        weights and without being moved to ``device``.
    """
    model = create_model(cfg, agent_config, arch_config, task_config, device)

    checkpoint_name = f"{arch_config.arch_name}_{agent_config.agent_type}_{task_config.benchmark}_{cfg.meta_train_epoch}_init.pth"
    old_state = torch.load(os.path.join(checkpoint_dir, checkpoint_name), map_location=device)

    remapped = _remap_checkpoint_keys(old_state, arch_config.arch_name)
    if remapped is None:
        # TODO: add key mappings for further architectures.
        return model

    # Start from the model's complete state (parameters *and* buffers) so a
    # strict load succeeds, then overlay the checkpoint values.
    merged_state = dict(model.state_dict())
    merged_state.update(remapped)
    model.load_state_dict(merged_state)
    model.to(device)

    return model


# ------------------------------------------------------------------------
# 9) Main Experiment
# ------------------------------------------------------------------------
@hydra.main(config_path='configs/sl', config_name='config')
def main(cfg: DictConfig):
    """Run the continual-learning experiment described by ``cfg``.

    For each of ``cfg.task.num_tasks`` tasks the model is evaluated zero-shot,
    trained with ``train_one_task`` and its optimizer is re-created. EWC-style
    agents then update their Fisher information on a random tenth of the
    task's training data. With ``cfg.monitor_backward_transfer`` enabled, a
    copy of the model is re-adapted to earlier tasks (only the previous task,
    or all earlier tasks after the last one) and the accuracy curves are
    saved with ``save_artifact``. Metrics go to W&B when ``cfg.wandb`` is set.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    set_seed(cfg.seed)

    agent_config = cfg.agent
    arch_config = cfg.arch
    task_config = cfg.task

    get_task_dataset, _, _ = configure_dataset(task_config=task_config, arch_config=arch_config, args=cfg)

    model = load_weights(cfg, agent_config, arch_config, task_config, device)
    model.get_optimizer()

    print(f"\n===== Training Model: {cfg.agent} =====")

    # Compare the agent *type*: `cfg.agent` is a config node and never equals
    # a plain string, which made the NeuroSync branch unreachable.
    if cfg.agent.agent_type == "NeuroSyncAgent":
        orchs_param_count, network_param_count = model.model.compute_total_params()
        print('Total number of trainable parameters:', orchs_param_count + network_param_count)
        print('Total number of trainable parameters in the network:', network_param_count)
        print('Percentage of trainable parameters in the network:', network_param_count / (orchs_param_count + network_param_count) * 100)
    else:
        print('Total number of trainable parameters:', model.model.compute_total_params())

    model.to(device)
    global_counter = 0
    # NOTE(review): `tranform` presumably matches the spelling of the config
    # key — confirm before renaming.
    exp_name = f'{model.__class__.__name__}_{arch_config.arch_name}_{agent_config.optimizer}_{agent_config.lr}_{task_config.benchmark}_transform={task_config.tranform}_{cfg.seed}'
    exp_name = 'C_MAML_' + exp_name

    if cfg.wandb:
        merged_config = {}
        merged_config.update(namespace_to_dict(task_config))
        merged_config.update(namespace_to_dict(agent_config))
        merged_config.update(namespace_to_dict(arch_config))
        merged_config['seed'] = cfg.seed  # keep args fields if you like
        wandb.init(
            project=cfg.proj_name,
            name=exp_name,
            group=task_config.benchmark,
            config=merged_config,
            save_code=True,
        )

    num_tasks = task_config.num_tasks
    # bt_*_matrix[j][i] holds the accuracy curve obtained when the model that
    # has finished task j is re-adapted to the earlier task i.
    bt_train_matrix = [[[] for _ in range(num_tasks)] for _ in range(num_tasks)]
    bt_test_matrix = [[[] for _ in range(num_tasks)] for _ in range(num_tasks)]
    length_k = 0

    train_loader_list = []
    test_loader_list = []
    for task_id in range(num_tasks):
        train_dataset_task, test_dataset_task, reoccured = get_task_dataset(task_id)

        g = torch.Generator()
        g.manual_seed(cfg.seed)

        # The task source may hand back ready-made loaders; use them as they
        # are instead of leaving the loader variables unset (or stale from
        # the previous task).
        if isinstance(train_dataset_task, DataLoader):
            train_loader = aux_train_loader = train_dataset_task
        else:
            # NOTE(review): both loaders draw from the same generator, so
            # evaluation passes over `aux_train_loader` advance the training
            # shuffle order — confirm this coupling is intended.
            train_loader = DataLoader(train_dataset_task, batch_size=task_config.batch_size, shuffle=True, generator=g)
            aux_train_loader = DataLoader(train_dataset_task, batch_size=task_config.batch_size, shuffle=True, generator=g)

        if isinstance(test_dataset_task, DataLoader):
            test_loader = aux_test_loader = test_dataset_task
        else:
            test_loader = DataLoader(test_dataset_task, batch_size=task_config.batch_size, shuffle=False)
            aux_test_loader = DataLoader(test_dataset_task, batch_size=task_config.batch_size, shuffle=False)

        train_loader_list.append(train_loader)
        test_loader_list.append(test_loader)

        zero_shot_train_acc, zero_shot_test_acc = evaluate_one_task(model, device, train_loader, test_loader)

        (few_shot_train_acc, few_shot_test_acc,
         final_train_acc, final_test_acc,
         train_result, test_result,
         metrics, global_counter) = train_one_task(
            task_id,
            model,
            cfg, agent_config, arch_config, task_config, device,
            train_loader,
            test_loader,
            epochs=task_config.epochs,
            reoccured=reoccured,
            global_counter=global_counter,
            eval_interval=cfg.train_eval_interval,
            wandb_enable=cfg.wandb,
            aux_train_loader=aux_train_loader,
            aux_test_loader=aux_test_loader,
        )

        model.get_optimizer()

        if cfg.agent.agent_type in ['EWCAgent', 'L2InitPlusEWCAgent']:
            # Materialise the finished task's training data as two tensors.
            if isinstance(train_dataset_task, DataLoader):
                xs_list = []
                ys_list = []
                for x, y in train_dataset_task:
                    xs_list.append(x)
                    ys_list.append(y)

                fisher_xs = torch.cat(xs_list, dim=0)
                fisher_ys = torch.cat(ys_list, dim=0)
            else:
                full_batch_loader = DataLoader(train_dataset_task, batch_size=len(train_dataset_task), shuffle=False)
                fisher_xs, fisher_ys = next(iter(full_batch_loader))

            # Keep a random tenth of the samples for the Fisher estimate.
            dataset_len = len(fisher_xs)
            indices = np.arange(dataset_len)
            np.random.shuffle(indices)
            keep = dataset_len // 10
            fisher_xs = fisher_xs[indices][:keep]
            fisher_ys = fisher_ys[indices][:keep]

            model.update_params_and_fisher(
                fisher_xs, fisher_ys, batch_size=task_config.batch_size)

        log_info = {}

        if cfg.monitor_backward_transfer:
            # After the last task every earlier task is revisited; after any
            # other task only the immediately preceding one.
            is_last_task = task_id == num_tasks - 1
            first_revisited = 0 if is_last_task else max(task_id - 1, 0)

            for i_task in range(first_revisited, task_id):
                bwa_train_loader = train_loader_list[i_task]
                bwa_test_loader = test_loader_list[i_task]

                # Adapt a copy so the main model is not changed by the probe.
                aux_model = create_model(cfg, agent_config, arch_config, task_config, device)
                aux_model.load_state_dict(model.state_dict())
                if cfg.agent.agent_type == "NeuroSyncAgent":
                    aux_model.load_consolidated_weights(model.model._ema_params)
                aux_model.get_optimizer()

                print(f'Backward adaptation task {task_id} to task {i_task}')
                bwa_train_curve, bwa_test_curve = BWA_one_task(
                    aux_model,
                    cfg, agent_config, arch_config, task_config, device,
                    bwa_train_loader,
                    bwa_test_loader,
                    cfg.num_epoch_backward_adaptation,
                    eval_interval=cfg.bwa_eval_interval,
                )

                length_k = len(bwa_train_curve)
                print(f'Backward adaptation task {task_id} to task {i_task} len: {length_k}')
                bt_train_matrix[task_id][i_task] = bwa_train_curve
                bt_test_matrix[task_id][i_task] = bwa_test_curve

        if reoccured:
            log_info[f'reoccuring/train_acc_{task_config.benchmark}'] = final_train_acc
            log_info[f'reoccuring/test_acc_{task_config.benchmark}'] = final_test_acc
            log_info[f'reoccuring/few_epoch_train_acc_{task_config.benchmark}'] = few_shot_train_acc
            log_info[f'reoccuring/few_epoch_test_acc_{task_config.benchmark}'] = few_shot_test_acc

        if cfg.wandb:
            log_info[f"final_train_acc_{task_config.benchmark}"] = final_train_acc
            log_info[f'final_test_acc_{task_config.benchmark}'] = final_test_acc
            log_info[f'zero_shot_train_acc_{task_config.benchmark}'] = zero_shot_train_acc
            log_info[f'zero_shot_test_acc_{task_config.benchmark}'] = zero_shot_test_acc

            wandb.log(log_info, step=global_counter)

        # Report the test accuracy the label promises (the train accuracy was
        # printed here before).
        print(f" Task {task_id+1}/{num_tasks} -> test acc: {final_test_acc:.3f}")

    if cfg.monitor_backward_transfer:
        save_artifact(bt_train_matrix, f'{num_tasks}_{cfg.agent.agent_type}_train_backward', cfg.wandb, num_tasks, length_k)
        save_artifact(bt_test_matrix, f'{num_tasks}_{cfg.agent.agent_type}_test_backward', cfg.wandb, num_tasks, length_k)


# Script entry point. `main` is wrapped by `@hydra.main`, so it is invoked
# here without arguments.
if __name__ == "__main__":
    main()





