# -*- coding: utf-8 -*-
"""
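Split mini-ImageNet continual-learning benchmark: an entropy-driven gain-modulated
classifier head versus SGD, momentum SGD and Adam, with and without optimizer-state
resets between contexts.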
"""

import os
import random
import pickle
from typing import List, Dict, Tuple, Union, Iterable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu, avg_pool2d
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, RandomSampler
from torchvision import transforms
from PIL import Image
from torch.distributions import Categorical

# Prefer the GPU when one is visible; models and batches are moved to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Global seed for the Python, NumPy and torch RNGs at import time
# (run_all_models_nSims re-seeds per simulation).
SEED = 42
for _seed_all in (random.seed, np.random.seed, torch.manual_seed):
    _seed_all(SEED)
del _seed_all

# Let cuDNN benchmark convolution algorithms for the observed input shapes:
# faster, but the selected algorithms may differ between runs.
torch.backends.cudnn.benchmark = True

## Load data
# Built with os.path.join so the same relative location resolves on Windows and
# POSIX alike (a literal backslash path is a single odd filename on Linux/macOS).
# Set the MINIIMAGENET_PKL environment variable to point elsewhere.
PKL_PATH = os.environ.get(
    "MINIIMAGENET_PKL",
    os.path.join(".", "miniImagenet", "miniImagnet.pkl"),
)

# ImageNet channel statistics shared by both pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: HxWxC image arrays -> PIL -> random horizontal flip ->
# normalized CHW float tensor.
# (Disabled alternative: transforms.RandomResizedCrop(84, scale=(0.9, 1.0)).)
train_tf = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Evaluation pipeline: same conversion and normalization, no augmentation.
# (Disabled alternative: transforms.Resize(96) then transforms.CenterCrop(84).)
test_tf = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

class XYDataset(Dataset):
    """Map-style dataset over paired in-memory samples ``x`` and labels ``y``.

    Item ``i`` is ``(transform(x[i]), int(y[i]))``, or ``(x[i], int(y[i]))`` when
    no transform is given. ``data`` and ``targets`` expose the raw containers.
    """

    def __init__(self, x, y, transform=None):
        self.data, self.targets, self.transform = x, y, transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = int(self.targets[idx])
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

def load_miniimgnet_from_pkl(pkl_path: str,
                             train_transform=None,
                             test_transform=None,
                             n_train_per_class: int = 500,
                             n_test_per_class: int = 100) -> Tuple[Dataset, Dataset]:
    """Load mini-ImageNet from a pickle and split every class into train/test sets.

    The pickle must hold a mapping with ``"data"`` (a 4-D array of RGB images,
    channels last) and ``"labels"`` (one label per image). For each class, the
    first ``n_train_per_class`` images in file order go to the training set and
    the next ``n_test_per_class`` to the test set. Both sets are then shuffled
    with a NumPy generator seeded by the module-level ``SEED``.

    Args:
        pkl_path: path to the pickle file. ``pickle.load`` can execute arbitrary
            code, so only pass trusted files.
        train_transform: optional transform applied to training images.
        test_transform: optional transform applied to test images.
        n_train_per_class: training images taken per class (default 500).
        n_test_per_class: test images taken per class (default 100).

    Returns:
        ``(train_dataset, test_dataset)`` as ``XYDataset`` instances.

    Raises:
        ValueError: if the image array is not 4-D with 3 channels last, if image
            and label counts differ, if there are no images, or if a class has
            fewer than ``n_train_per_class + n_test_per_class`` images.
    """
    # NOTE(security): pickle.load runs arbitrary code from the file; trusted input only.
    with open(pkl_path, "rb") as f:
        dataset = pickle.load(f)

    X = np.asarray(dataset["data"])
    y = np.asarray(dataset["labels"])
    # Explicit exceptions rather than `assert`, which `python -O` strips.
    if X.ndim != 4 or X.shape[-1] != 3:
        raise ValueError(f"Expect RGB images HxWxC, got array of shape {X.shape}")
    if len(X) != len(y):
        raise ValueError(f"X and y must align: {len(X)} images vs {len(y)} labels")

    classes = sorted(set(y.tolist()))
    if not classes:
        raise ValueError("Dataset contains no images.")
    if len(classes) != 100:
        print(f"Warning: found {len(classes)} classes (expected 100)")

    n_needed = n_train_per_class + n_test_per_class
    train_x, train_y, test_x, test_y = [], [], [], []
    for cls in classes:
        # np.where returns indices in ascending order, so the split follows file order.
        idx = np.where(y == cls)[0]
        if len(idx) < n_needed:
            raise ValueError(f"Class {cls} has only {len(idx)} images (<{n_needed}).")
        train_idx = idx[:n_train_per_class]
        test_idx = idx[n_train_per_class:n_needed]
        train_x.append(X[train_idx])
        train_y.append(y[train_idx])
        test_x.append(X[test_idx])
        test_y.append(y[test_idx])

    train_x = np.concatenate(train_x, axis=0)
    train_y = np.concatenate(train_y, axis=0)
    test_x = np.concatenate(test_x, axis=0)
    test_y = np.concatenate(test_y, axis=0)

    # Shuffle so that the arrays are no longer grouped by class; the training
    # permutation is drawn first, then the test permutation, from one generator.
    rng = np.random.default_rng(SEED)
    p = rng.permutation(len(train_x))
    train_x, train_y = train_x[p], train_y[p]
    p = rng.permutation(len(test_x))
    test_x, test_y = test_x[p], test_y[p]

    return (XYDataset(train_x, train_y, transform=train_transform),
            XYDataset(test_x, test_y, transform=test_transform))

# Train and test datasets. Both stay None when the pickle is missing; the run
# section at the bottom of the file checks train_ds before starting.
train_ds = test_ds = None
if not os.path.exists(PKL_PATH):
    print(f"ERROR: Dataset not found at {PKL_PATH}")
else:
    train_ds, test_ds = load_miniimgnet_from_pkl(PKL_PATH,
                                                 train_transform=train_tf,
                                                 test_transform=test_tf)

## Network Model
# --- Conv ---
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

# --- Slim ResNet ---
class BasicBlock(nn.Module):
    """Two-convolution residual unit: conv-BN-ReLU, conv-BN, add shortcut, ReLU.

    ``expansion`` is the ratio of output channels to ``planes`` (1 here). When the
    stride or the channel count changes, the shortcut is a strided 1x1 conv + BN
    projection; otherwise it is an empty ``nn.Sequential`` (identity).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # Submodules are created in a fixed order: conv1, bn1, conv2, bn2, shortcut.
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        h = self.bn2(self.conv2(relu(self.bn1(self.conv1(x)))))
        h = h + self.shortcut(x)
        return relu(h)

class ResNet(nn.Module):
    """Slim ResNet feature extractor with no classifier head.

    The stem is a 3x3 conv (3 -> ``nf`` channels) followed by BN and ReLU. It is
    followed by four stages of ``block`` with widths ``nf, 2nf, 4nf, 8nf`` and
    first-block strides ``1, 2, 2, 2``. ``forward`` returns flattened features
    of length ``self.feature_size``. With ``global_pooling``, the final feature
    map is first averaged over its full spatial extent.

    Args:
        block: residual block class called as ``block(in_planes, planes, stride)``
            and exposing an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        nf: base channel width.
        global_pooling: average-pool the last feature map before flattening.
        input_size: ``(C, H, W)`` of a single input, with ``C == 3``. It is used
            only to infer ``feature_size``.
    """

    def __init__(self, block, num_blocks, nf, global_pooling, input_size):
        super().__init__()
        self.global_pooling = global_pooling

        self.in_planes = nf
        self.conv1 = conv3x3(3, nf)
        self.bn1   = nn.BatchNorm2d(nf)
        self.layer1 = self._make_layer(block, nf,     num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)

        # Infer the flattened feature size with a dummy forward pass run in eval
        # mode. In training mode the all-zero probe would update every
        # BatchNorm's running statistics before any real data is seen. It would
        # also fail outright when the last feature map is 1x1, because
        # training-mode BN needs more than one value per channel.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, *input_size)
            self.feature_size = self.forward(dummy).size(1)
        self.train(was_training)

    def _make_layer(self, block, planes, n_blocks, stride):
        """Stack ``n_blocks`` blocks; only the first uses ``stride``. Updates ``in_planes``."""
        strides = [stride] + [1] * (n_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward_feats(self, x):
        """Return the unpooled output feature map of the last stage."""
        out = relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out

    def forward(self, x):
        out = self.forward_feats(x)
        if self.global_pooling:
            # Pool over the full (H, W) extent so non-square maps are handled too.
            out = avg_pool2d(out, tuple(out.shape[-2:]))
        return out.view(out.size(0), -1)

def ResNet18feat(input_size, nf=20, global_pooling=True):
    """Build the slim ResNet-18 feature extractor.

    The network has four stages of two ``BasicBlock``s each and base width ``nf``.
    """
    blocks_per_stage = [2] * 4
    return ResNet(block=BasicBlock,
                  num_blocks=blocks_per_stage,
                  nf=nf,
                  global_pooling=global_pooling,
                  input_size=input_size)

# --- Head ---
class GainModLinear(nn.Module):
    """Linear layer whose output rows are scaled by non-trainable per-output gains.

    The effective weight is ``weight * gains[:, None]``. Every training-mode forward
    pass first updates the gains in place:
    ``gains <- gamma * gains + (1 - gamma) * g0 + eta * H_drive``.
    In eval mode the gains are left untouched.
    """

    def __init__(self, in_dim, out_dim, g0=1.0, gamma=0.9, eta=0.0):
        super().__init__()
        init_weight = torch.randn(out_dim, in_dim) * 0.05
        self.weight = nn.Parameter(init_weight)
        self.bias = nn.Parameter(torch.zeros(out_dim))
        # Registered as a frozen Parameter (not a buffer): it appears in
        # named_parameters() under "gains", which is how optimizers exclude it.
        self.gains = nn.Parameter(torch.ones(out_dim) * g0, requires_grad=False)
        self.g0 = g0
        self.gamma = gamma
        self.eta = eta

    @torch.no_grad()
    def decay(self, H_drive: float = 0.0):
        """Relax the gains toward ``g0`` and add the ``eta * H_drive`` drive, in place."""
        g = self.gains.mul_(self.gamma)
        g.add_((1 - self.gamma) * self.g0)
        g.add_(self.eta * H_drive)

    def forward(self, x, H_drive: float = 0.0):
        if self.training:
            self.decay(H_drive)
        row_scale = self.gains.unsqueeze(1)
        return F.linear(x, self.weight * row_scale, self.bias)

# --- Network ---
class GainSlimResNet18(nn.Module):
    """Slim ResNet-18 backbone (global pooling) followed by a ``GainModLinear`` head.

    ``forward(x, H_prev)`` passes ``H_prev`` to the head as its gain drive.
    """

    def __init__(self, input_size=(3,84,84), nf=20, output_dim=100, g0=1.0, gamma=0.9, eta=0.0):
        super().__init__()
        self.backbone = ResNet18feat(input_size, nf, global_pooling=True)
        self.fc = GainModLinear(self.backbone.feature_size, output_dim,
                                g0=g0, gamma=gamma, eta=eta)

    def forward(self, x, H_prev: float = 0.0):
        return self.fc(self.backbone(x), H_prev)
    
## Continual evaluation
def _class_indices(targets: torch.Tensor, classes: List[int]) -> List[int]:
    """Return, as Python ints, the positions in ``targets`` whose label is in ``classes``."""
    mask = torch.isin(targets, torch.tensor(classes))
    # flatten() keeps the result 1-D even when exactly one sample matches.
    # squeeze() would give a 0-d tensor, and Subset cannot take len() of that.
    return torch.nonzero(mask, as_tuple=False).flatten().tolist()


def _evaluate(model, loader, criterion, device, H_prev: float) -> Tuple[float, float]:
    """Return ``(accuracy in percent, sample-weighted mean loss)`` of ``model`` over ``loader``.

    Runs under ``torch.no_grad()``; the caller sets eval/train mode. Assumes
    ``criterion`` returns a per-batch mean, which is re-weighted by batch size here.
    """
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for x_t, y_t in loader:
            x_t, y_t = x_t.to(device), y_t.to(device)
            logits = model(x_t, H_prev)
            loss_sum += criterion(logits, y_t).item() * y_t.size(0)
            correct += (logits.argmax(1) == y_t).sum().item()
            total += y_t.size(0)
    return 100.0 * correct / total, loss_sum / total


def continual_train(model, optimizer, criterion,
                    contexts: List[List[int]],
                    batch: int, ctx_iter: int, rho_eval: int,
                    train_ds, test_ds, device, mode: str,
                    test_ctx: List[int],
                    reset_optimizer: bool = False) -> Dict[str, list]:
    """Train ``model`` on a sequence of contexts and record continual metrics.

    For each context (a list of class labels), the model is trained for
    ``ctx_iter`` steps. Each step uses a minibatch of ``batch`` samples drawn
    with replacement from the training samples of those classes. Every
    ``rho_eval`` global steps, the model is evaluated in eval mode, without
    gradients, on the test samples whose labels are in ``test_ctx``.

    Args:
        model: module called as ``model(x, H_prev)``; must expose
            ``model.fc.weight`` and ``model.fc.gains``.
        optimizer: torch optimizer over the model's trainable parameters.
        criterion: loss function ``criterion(logits, targets)``.
        contexts: class-label lists, one per context, trained in order.
        batch: minibatch size for training and evaluation.
        ctx_iter: optimizer steps per context.
        rho_eval: evaluate every ``rho_eval`` steps; must be >= 1.
        train_ds, test_ds: datasets exposing a ``targets`` sequence of labels.
        device: device that batches are moved to.
        mode: ``"gain-H"`` passes the mean predictive entropy of the previous
            training batch to the model as ``H_prev``; any other value keeps
            ``H_prev`` at 0.
        test_ctx: class labels that make up the (fixed) evaluation set.
        reset_optimizer: if True, clear ``optimizer.state`` at the start of
            every context after the first.

    Returns:
        Dict of per-step lists ``acc_train``, ``loss_train``, ``w_out``
        (mean ``|fc.weight|``) and ``gain_out`` (mean ``fc.gains``). It also
        holds per-evaluation lists ``acc_test`` (percent) and ``loss_test``.

    Raises:
        ValueError: if ``rho_eval < 1`` or a context / ``test_ctx`` selects no samples.
    """
    if rho_eval < 1:
        raise ValueError(f"rho_eval must be a positive integer, got {rho_eval!r}")

    train_targets = torch.tensor(train_ds.targets)
    test_targets  = torch.tensor(test_ds.targets)

    hist = {k: [] for k in ["acc_train","loss_train","acc_test","loss_test","w_out","gain_out"]}

    # The evaluation set depends only on test_ctx, which is identical for every
    # context, so one loader suffices. A loader per context would just evaluate
    # the same samples repeatedly.
    idx_test = _class_indices(test_targets, test_ctx)
    if not idx_test:
        raise ValueError(f"No test samples belong to test_ctx={test_ctx}")
    test_loader = DataLoader(Subset(test_ds, idx_test), batch_size=batch, shuffle=False)

    step = 0
    H_prev = 0.0

    for task_id, ctx in enumerate(contexts):
        print(f"\nContext {task_id+1}/{len(contexts)}: {ctx}")

        # --- ORACLE RESET: drop optimizer state (momentum / moments) at context boundaries ---
        if reset_optimizer and task_id > 0:
            optimizer.state.clear()

        # --- TRAIN DATA ---
        idx_train = _class_indices(train_targets, ctx)
        if not idx_train:
            raise ValueError(f"No training samples belong to context {task_id}: {ctx}")
        subset_train = Subset(train_ds, idx_train)
        # Sampling ctx_iter * batch indices with replacement gives exactly ctx_iter full batches.
        train_loader = DataLoader(
            subset_train,
            batch_size=batch,
            sampler=RandomSampler(subset_train, replacement=True, num_samples=ctx_iter * batch),
        )
        train_iter = iter(train_loader)

        # --- TRAINING ---
        model.train()
        for _ in range(ctx_iter):
            inputs, targets = next(train_iter)
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs, H_prev)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                if mode == "gain-H":
                    # Passed to the model as H_prev on the next forward (and evaluation).
                    H_prev = Categorical(logits=outputs).entropy().mean().item()
                acc_train = (outputs.argmax(1) == targets).float().mean().item()

            hist["acc_train"].append(acc_train)
            hist["loss_train"].append(loss.item())
            hist["w_out"].append(model.fc.weight.data.abs().mean().item())
            hist["gain_out"].append(model.fc.gains.mean().item())

            # --- CONTINUAL EVALUATION ---
            if step % rho_eval == 0:
                model.eval()
                acc_test, loss_test = _evaluate(model, test_loader, criterion, device, H_prev)
                hist["acc_test"].append(acc_test)
                hist["loss_test"].append(loss_test)
                model.train()

            step += 1

    return hist

# -----------------------------
# contexts (Split Mini-ImageNet)
# -----------------------------
def build_contexts(num_classes=100, classes_per_task=5, seed=SEED, num_contexts=None):
    """Return cumulative class-label contexts over a seeded random class order.

    The labels ``0 .. num_classes-1`` are permuted with
    ``np.random.default_rng(seed)`` and cut into consecutive blocks of
    ``classes_per_task``; the last block may be shorter. Context ``i`` lists
    every class in blocks ``0 .. i``. When ``num_contexts`` is given, only the
    first ``num_contexts`` contexts are returned.
    """
    order = np.random.default_rng(seed).permutation(num_classes).tolist()
    # The blocks are consecutive slices of one permutation, so the cumulative
    # context ending at a block is simply a prefix of that permutation.
    prefixes = [order[:start + classes_per_task]
                for start in range(0, num_classes, classes_per_task)]
    return prefixes if num_contexts is None else prefixes[:num_contexts]

## Hyperparameters
batch = 256          # minibatch size (training and evaluation loaders)
ctx_iter = 200       # optimizer steps per context
rho_eval = 1         # evaluate every `rho_eval` steps (1 = every step)
lr_weights = 0.01    # base learning rate; Adam and plain SGD use scaled copies

# Five cumulative contexts, each adding 10 new classes (50 of the 100 classes in total).
contexts = build_contexts(num_classes=100, classes_per_task=10,
                          seed=SEED, num_contexts=5)
# Test metrics are always measured on the classes of the first context only.
test_ctx = contexts[0]
criterion = nn.CrossEntropyLoss()

## Simulation

def run_all_models_nSims(nSims, base_seed=0):
    """Train every model/optimizer variant for ``nSims`` seeds and collect histories.

    Simulation ``sim`` uses seed ``base_seed + sim``. It seeds Python, NumPy,
    torch and, when available, all CUDA devices. Then, in order, it builds a
    fresh ``GainSlimResNet18`` per variant and runs ``continual_train`` on the
    module-level ``contexts`` / ``test_ctx`` configuration. The variants are:
    entropy-gain SGD, momentum SGD, momentum SGD with optimizer reset, Adam,
    Adam with reset, and plain SGD. The output-layer bias and the gains are
    never handed to an optimizer.

    Returns:
        A list with one ``{"seed": seed, "results": {name: history}}`` per simulation.
    """

    def trainable(model):
        # Everything except the frozen gains and the output-layer bias.
        return [p for n, p in model.named_parameters()
                if 'gains' not in n and n != 'fc.bias']

    def make_entropy_gain_opt(model):
        named = list(model.named_parameters())
        backbone = [p for n, p in named if 'gains' not in n and not n.startswith('fc.')]
        head_weight = [p for n, p in named if n == 'fc.weight']
        return optim.SGD([
            {'params': backbone, 'lr': lr_weights, 'momentum': 0.9},
            # The gain-modulated head weight is trained without momentum.
            {'params': head_weight, 'lr': lr_weights, 'momentum': 0.0},
        ])

    def make_msgd_opt(model):
        return optim.SGD([{'params': trainable(model), 'lr': lr_weights}], momentum=0.9)

    def make_adam_opt(model):
        return optim.Adam([{'params': trainable(model), 'lr': 0.1*lr_weights}])

    def make_sgd_opt(model):
        return optim.SGD([{'params': trainable(model), 'lr': 10*lr_weights}], momentum=0.0)

    # (banner, result key, gamma, eta, optimizer factory, mode, reset optimizer state)
    variants = [
        ("ENTROPY GAIN MODEL", "ENTROPY GAIN", 0.9, 0.1, make_entropy_gain_opt, "gain-H",  False),
        ("MOMENTUM SGD",       "MSGD",         0.0, 0.0, make_msgd_opt,         "no-gain", False),
        ("MOMENTUM SGD RESET", "MSGD reset",   0.0, 0.0, make_msgd_opt,         "no-gain", True),
        ("VANILLA ADAM",       "ADAM",         0.0, 0.0, make_adam_opt,         "no-gain", False),
        ("VANILLA ADAM RESET", "ADAM reset",   0.0, 0.0, make_adam_opt,         "no-gain", True),
        ("SGD",                "SGD",          0.0, 0.0, make_sgd_opt,          "no-gain", False),
    ]
    # Key order of each stored results dict (differs from the run order above).
    result_order = ["ENTROPY GAIN", "ADAM", "ADAM reset", "MSGD", "MSGD reset", "SGD"]

    all_results = []
    for sim in range(nSims):
        seed = base_seed + sim
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        print(f"\n=== Sim {sim+1}/{nSims}, seed={seed} ===")

        histories = {}
        for banner, key, gamma, eta, make_opt, mode, reset in variants:
            print(banner)
            model = GainSlimResNet18(output_dim=100, nf=20, g0=1.0,
                                     gamma=gamma, eta=eta).to(device)
            histories[key] = continual_train(model, make_opt(model), criterion,
                                             contexts, batch, ctx_iter, rho_eval,
                                             train_ds, test_ds, device, mode, test_ctx,
                                             reset_optimizer=reset)

        all_results.append({"seed": seed,
                            "results": {k: histories[k] for k in result_order}})

    return all_results

## Run
RESULTS_PATH = "Rex3_splitminiImagenet_reset.pkl"

# Guarded so that importing this module (e.g. to reuse the model classes or
# load_miniimgnet_from_pkl) does not launch the full multi-seed experiment.
if __name__ == "__main__":
    if train_ds is not None:
        all_sims = run_all_models_nSims(nSims=5, base_seed=42)

        # Save
        with open(RESULTS_PATH, "wb") as f:
            pickle.dump(all_sims, f)
        print(f"Done! Saved to {RESULTS_PATH}")
    else:
        print("Skipping run due to missing dataset.")
