# -*- coding: utf-8 -*-
"""
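Rotated MNIST continual-learning experiment with Experience Replay (1K-sample
class-balanced buffer; replayed samples fill 50% of each batch), comparing an
entropy-driven gain-modulated network against SGD, momentum SGD and Adam, the
latter two also with an oracle optimizer-state reset at each context switch.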
"""

import functools
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
from torch.distributions import Categorical
from torch.utils.data import DataLoader, RandomSampler, ConcatDataset
from torchvision import datasets, transforms

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Load data
class RotatableMNIST(datasets.MNIST):
    """MNIST dataset that can derive rotated copies of itself via :meth:`rot`."""

    def __init__(self, root, train=True, download=False, transform=None):
        super().__init__(root=root, train=train, download=download,
                         transform=transform)
        # Remember the caller's transform so rot() can prepend a rotation to it.
        self._base_transform = transform

    def rot(self, angle):
        """
        Return a new RotatableMNIST (same root and split) rotated by ``angle`` degrees.

        The rotation runs first, followed by this dataset's own transform. With
        the Compose defined in this file, that means it runs before ToTensor and
        Normalize. Calling rot() on an already-rotated dataset composes the
        rotations.
        """
        if isinstance(self._base_transform, transforms.Compose):
            base_ts = self._base_transform.transforms
        elif self._base_transform is not None:
            base_ts = [self._base_transform]
        else:
            base_ts = []
        # functools.partial (unlike a lambda) is picklable, so the returned
        # dataset also works with multi-process DataLoader workers.
        rotate = transforms.Lambda(functools.partial(TF.rotate, angle=angle))
        return RotatableMNIST(
            root=self.root,
            train=self.train,
            # This instance already loaded its data from `root`, so no download
            # is needed. (`self.download` is the dataset's download method, not
            # a boolean flag, and must not be passed here.)
            download=False,
            transform=transforms.Compose([rotate] + base_ts),
        )

# Shared preprocessing: PIL image -> tensor, then normalisation with
# mean 0.1307 / std 0.3081 (presumably the MNIST training-set pixel statistics).
base_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Unrotated train and test splits (built in that order); rotated views are
# derived from them later with .rot(angle).
train_ds, test_ds = (
    RotatableMNIST(root="./data", train=is_train, download=True, transform=base_transform)
    for is_train in (True, False)
)

## Network Model
# --- Layers ---
class GainModLayer(nn.Module):
    """
    Fully connected ReLU layer whose pre-activations are scaled per unit by ``gains``.

    ``gains`` is stored as a non-trainable ``nn.Parameter``
    (``requires_grad=False``). On every forward pass in training mode it is
    first updated in place:

        gains <- gamma * gains + (1 - gamma) * g0 + eta * H_drive

    The output is ``relu((x @ weight.T) * gains + bias)``. The weights are
    initialised as ``0.01 * N(0, 1)`` and the biases as zeros.
    """

    def __init__(self, input_dim, output_dim, g0=1.0, gamma=0.9, eta: float = 0.0):
        super().__init__()
        # Registration order (weight, bias, gains) fixes named_parameters() order.
        self.weight = nn.Parameter(0.01 * torch.randn(output_dim, input_dim))
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.gains = nn.Parameter(torch.ones(output_dim), requires_grad=False)
        self.g0, self.gamma, self.eta = g0, gamma, eta

    def decay(self, H_drive: float = 0.0):
        """Relax the gains toward ``g0`` and add ``eta * H_drive``, in place."""
        with torch.no_grad():
            g = self.gains
            g.mul_(self.gamma)
            g.add_((1 - self.gamma) * self.g0)
            g.add_(self.eta * H_drive)

    def forward(self, x, H_drive: float = 0.0):
        # Gains only evolve while training; eval-mode passes leave them untouched.
        if self.training:
            self.decay(H_drive)
        pre = torch.matmul(x, self.weight.t())
        return torch.relu(pre * self.gains + self.bias)


class GainModOut(nn.Module):
    """
    Gain-modulated linear read-out layer with no non-linearity (returns raw logits).

    ``gains`` is a non-trainable ``nn.Parameter``. On every forward pass in
    training mode it is first updated in place:

        gains <- gamma * gains + (1 - gamma) * g0 + eta * H_drive

    The output is ``(x @ weight.T) * gains + bias``. The weights are
    initialised as ``0.05 * N(0, 1)`` and the biases as zeros.
    """

    def __init__(self, input_dim, output_dim, g0=1.0, gamma=0.9, eta: float = 0.0):
        super().__init__()
        # Registration order (weight, bias, gains) fixes named_parameters() order.
        self.weight = nn.Parameter(0.05 * torch.randn(output_dim, input_dim))
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.gains = nn.Parameter(torch.ones(output_dim), requires_grad=False)
        self.g0, self.gamma, self.eta = g0, gamma, eta

    def decay(self, H_drive: float = 0.0):
        """Relax the gains toward ``g0`` and add ``eta * H_drive``, in place."""
        with torch.no_grad():
            g = self.gains
            g.mul_(self.gamma)
            g.add_((1 - self.gamma) * self.g0)
            g.add_(self.eta * H_drive)

    def forward(self, x, H_drive: float = 0.0):
        # Gains only evolve while training; eval-mode passes leave them untouched.
        if self.training:
            self.decay(H_drive)
        pre = torch.matmul(x, self.weight.t())
        return pre * self.gains + self.bias

# --- Network ---
class GainModNet(nn.Module):
    """
    MLP made of two gain-modulated ReLU layers followed by a gain-modulated read-out.

    All three layers share the same ``g0`` / ``gamma`` / ``eta`` settings. The
    scalar ``H_drive`` given to :meth:`forward` is passed unchanged to every layer.
    """

    def __init__(self, input_dim=28*28, hidden_dim=400, output_dim=10,
                 g0=1.0, gamma=0.9, eta: float = 0.0):
        super().__init__()
        gain_kwargs = dict(g0=g0, gamma=gamma, eta=eta)
        # Construction order (layer1, layer2, out) fixes parameter order and RNG use.
        self.layer1 = GainModLayer(input_dim, hidden_dim, **gain_kwargs)
        self.layer2 = GainModLayer(hidden_dim, hidden_dim, **gain_kwargs)
        self.out = GainModOut(hidden_dim, output_dim, **gain_kwargs)

    def forward(self, x, H_drive: float = 0.0):
        # Flatten everything after the batch dimension.
        hidden = x.view(x.size(0), -1)
        for stage in (self.layer1, self.layer2, self.out):
            hidden = stage(hidden, H_drive)
        return hidden

# ======================================================================
# Experience Replay: class-balanced buffer (1K capacity, 50% batch)
# ======================================================================
def build_class_balanced_buffer(train_ds, buffer_size):
    """
    Build a class-balanced replay buffer from ``train_ds``.

    Each distinct label in ``train_ds.targets`` contributes
    ``max(1, buffer_size // n_classes)`` samples, drawn without replacement
    with ``np.random.choice``. A class with fewer samples than that
    contributes all of them. For 10 classes and ``buffer_size=1000`` this
    gives 100 samples per class. Because of the ``max(1, ...)`` floor, when
    ``0 < buffer_size < n_classes`` the buffer holds one sample per class,
    i.e. slightly more than ``buffer_size``.

    Args:
        train_ds: dataset exposing ``targets`` (a tensor or a sequence of ints)
            and ``__getitem__`` returning ``(x, y)`` with ``x`` a tensor.
        buffer_size: target total number of samples; ``<= 0`` yields an
            empty buffer.

    Returns:
        ``(buffer_x, buffer_y)``: lists of CPU tensors (inputs, and 0-d
        ``torch.long`` labels), grouped by class in ascending label order.
    """
    if buffer_size <= 0:
        return [], []

    # as_tensor avoids the copy-construct warning torch.tensor() emits when
    # `targets` is already a tensor.
    targets = torch.as_tensor(train_ds.targets)
    classes = torch.unique(targets).tolist()  # sorted distinct labels

    indices_by_class = {}
    for c in classes:
        # flatten() (not squeeze()) keeps a 1-D index array even for a class
        # with a single sample; a 0-d array would make len() below fail.
        idx_c = torch.nonzero(targets == c, as_tuple=False).flatten()
        if idx_c.numel() > 0:
            indices_by_class[c] = idx_c.cpu().numpy()

    n_classes = len(indices_by_class)
    if n_classes == 0:
        return [], []

    per_class_cap = max(1, buffer_size // n_classes)

    buffer_x, buffer_y = [], []
    for c, idxs in indices_by_class.items():
        if len(idxs) <= per_class_cap:
            chosen = idxs
        else:
            chosen = np.random.choice(idxs, size=per_class_cap, replace=False)
        for i in chosen:
            x, y = train_ds[int(i)]
            buffer_x.append(x.cpu())
            buffer_y.append(torch.tensor(int(y), dtype=torch.long))

    return buffer_x, buffer_y

## Continual evaluation with ER
def _mix_with_replay(inputs, targets, buffer_x, buffer_y, batch,
                     replay_fraction, device):
    """
    Replace part of a current-context batch with samples from the replay buffer.

    The batch is returned unchanged when the buffer is empty, replay is
    disabled, or ``int(batch * replay_fraction)`` is zero. Otherwise the result
    is ``batch - replay_size`` randomly chosen current samples followed by
    ``replay_size`` buffer samples drawn uniformly without replacement, where
    ``replay_size = min(int(batch * replay_fraction), len(buffer_x))``.
    """
    if len(buffer_x) == 0 or not replay_fraction > 0.0:
        return inputs, targets

    replay_size = min(int(batch * replay_fraction), len(buffer_x))
    if replay_size <= 0:
        # Nothing to replay; torch.stack would otherwise fail on an empty list.
        return inputs, targets
    current_size = batch - replay_size

    # Current samples: a random subset of the freshly drawn batch.
    if current_size > 0:
        perm_cur = torch.randperm(inputs.size(0))[:current_size]
        cur_x, cur_y = inputs[perm_cur], targets[perm_cur]
    else:
        cur_x, cur_y = inputs[:0], targets[:0]

    # Replay samples: uniform from the buffer, without replacement.
    idxs = np.random.choice(len(buffer_x), size=replay_size, replace=False)
    rep_x = torch.stack([buffer_x[i] for i in idxs]).to(device)
    rep_y = torch.stack([buffer_y[i] for i in idxs]).to(device)

    return torch.cat([cur_x, rep_x], dim=0), torch.cat([cur_y, rep_y], dim=0)


def _evaluate(model, loader, criterion, device, H_drive):
    """
    Return ``(accuracy in %, mean per-sample loss)`` of ``model`` over ``loader``.

    Runs in eval mode under ``torch.no_grad()``, then switches the model back
    to train mode. The loss average assumes ``criterion`` returns the batch
    mean (CrossEntropyLoss's default) — confirm if the criterion changes.
    """
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for x_t, y_t in loader:
            x_t, y_t = x_t.to(device), y_t.to(device)
            logits = model(x_t.view(x_t.size(0), -1), H_drive)
            loss_sum += criterion(logits, y_t).item() * y_t.size(0)
            correct += (logits.argmax(1) == y_t).sum().item()
            total += y_t.size(0)
    model.train()
    return 100.0 * correct / total, loss_sum / total


def _record_step_stats(hist, model, outputs, targets, loss):
    """Append this step's training accuracy/loss and weight/gain summaries to ``hist``."""
    hist["acc_train"].append((outputs.argmax(1) == targets).float().mean().item())
    hist["loss_train"].append(loss.item())
    # Mean absolute weight and mean gain of each layer.
    for suffix, layer in (("1", model.layer1), ("2", model.layer2), ("_out", model.out)):
        hist[f"w{suffix}"].append(layer.weight.data.abs().mean().item())
        hist[f"gain{suffix}"].append(layer.gains.mean().item())


def continual_train(model, optimizer, criterion, contexts, batch, ctx_iter,
                    rho_eval, train_ds, test_ds, device, mode,
                    reset_optimizer=False,
                    replay_buffer_size=1000,
                    replay_fraction=0.5):
    """
    Train ``model`` on a sequence of rotation contexts with experience replay.

    For each context (a list of rotation angles), the model trains for
    ``ctx_iter`` steps. Each step uses a batch of ``batch`` images sampled
    with replacement from the union of ``train_ds.rot(angle)`` over the
    context's angles. Once the replay buffer is non-empty, each batch also
    mixes in replayed samples (``replay_fraction`` of the batch). Every
    ``rho_eval`` global steps the model is evaluated on ``test_ds.rot(0)``.

    Args:
        model: network called as ``model(x_flat, H_drive)``; must expose
            ``layer1``, ``layer2`` and ``out``, each with ``weight`` and ``gains``.
        optimizer: torch optimizer over ``model``'s parameters.
        criterion: loss on ``(logits, labels)``.
        contexts: list of contexts, each a list of rotation angles (degrees).
        batch: mini-batch size for training and evaluation.
        ctx_iter: optimisation steps per context.
        rho_eval: evaluate every ``rho_eval`` global steps.
        train_ds, test_ds: datasets exposing ``rot(angle)``.
        device: device that batches are moved to.
        mode: ``"gain-H"`` feeds the previous step's mean output entropy back
            as ``H_drive``; any other value keeps ``H_drive`` at 0.
        reset_optimizer: clear the optimizer state at every context switch
            after the first (the "oracle" reset).
        replay_buffer_size: capacity passed to ``build_class_balanced_buffer``.
        replay_fraction: share of each batch drawn from the replay buffer.

    Returns:
        dict of lists, with one entry per step for ``acc_train`` (fraction),
        ``loss_train``, ``w1``, ``w2``, ``w_out``, ``gain1``, ``gain2`` and
        ``gain_out``, and one entry per evaluation for ``acc_test`` (percent)
        and ``loss_test``.
    """
    hist = {k: [] for k in [
        "acc_train", "loss_train",
        "acc_test",  "loss_test",
        "w1", "w2", "w_out",
        "gain1", "gain2", "gain_out"
    ]}
    step = 0
    H_prev = 0.0

    # ER buffer (global across contexts); CPU tensors.
    buffer_x, buffer_y = [], []

    # The evaluation set never changes (always the 0° test images), so build its
    # loader once. The earlier code added one identical loader per context, which
    # repeated the same evaluation 1x, 2x, 3x ... without changing the accuracy.
    test_loader = DataLoader(ConcatDataset([test_ds.rot(0)]),
                             batch_size=batch, shuffle=False)

    for task_id, ctx in enumerate(contexts):
        print(f"\nContext {task_id+1}/{len(contexts)}: {ctx}")

        # --- ORACLE RESET ---
        if reset_optimizer and task_id > 0:
            optimizer.state.clear()

        # --- TRAIN LOADER: exactly ctx_iter batches, drawn with replacement ---
        train_subsets = ConcatDataset([train_ds.rot(angle) for angle in ctx])
        sampler = RandomSampler(train_subsets, replacement=True,
                                num_samples=ctx_iter * batch)
        train_iter = iter(DataLoader(train_subsets, batch_size=batch, sampler=sampler))

        model.train()
        for _ in range(ctx_iter):
            inputs, targets = next(train_iter)
            inputs, targets = inputs.to(device), targets.to(device)
            mixed_x, mixed_y = _mix_with_replay(inputs, targets, buffer_x, buffer_y,
                                                batch, replay_fraction, device)

            # --- forward + backward + update ---
            optimizer.zero_grad()
            outputs = model(mixed_x.view(mixed_x.size(0), -1), H_prev)
            loss = criterion(outputs, mixed_y)
            loss.backward()
            optimizer.step()

            # The entropy of this step's predictions drives the gains on the next step.
            if mode == "gain-H":
                with torch.no_grad():
                    H_prev = Categorical(logits=outputs).entropy().mean().item()

            _record_step_stats(hist, model, outputs, mixed_y, loss)

            # --- CONTINUAL EVAL ---
            if step % rho_eval == 0:
                acc_test, loss_test = _evaluate(model, test_loader, criterion,
                                                device, H_prev)
                hist["acc_test"].append(acc_test)
                hist["loss_test"].append(loss_test)

            step += 1

        # --- update ER buffer at end of context (class-balanced over digits) ---
        # NOTE(review): the buffer is rebuilt from the same `train_ds` passed in,
        # not from the rotated copies used for training, so replayed samples
        # never carry this context's rotation — confirm this is the intended
        # replay scheme. This rebuild also runs after the final context.
        buffer_x, buffer_y = build_class_balanced_buffer(
            train_ds, buffer_size=replay_buffer_size
        )
        print(f"Replay buffer updated after task {task_id+1}: "
              f"{len(buffer_x)} samples.")

    return hist

# --- Hyperparameters ---
batch = 128                    # mini-batch size (current + replayed samples)
lr_weights = 0.01              # base weight learning rate; optimizers scale it
rho_eval = 1                   # evaluate on the test set every `rho_eval` steps
ctx_iter = 400                 # training steps per context
contexts = [[0], [80], [160]]  # rotation angles (degrees) for each context
criterion = nn.CrossEntropyLoss()

## Simulation
def _weight_only_param_groups(model, lr):
    """Split parameters into a weight group (learning rate ``lr``) and a rest group (lr 0)."""
    named = list(model.named_parameters())
    return [
        {'params': [p for n, p in named if 'weight' in n], 'lr': lr},
        # Biases and gains sit in a zero-learning-rate group, so they never move.
        {'params': [p for n, p in named if 'weight' not in n], 'lr': 0},
    ]


def _train_variant(label, make_optimizer, lr, *, mode, reset_optimizer,
                   gamma=0.0, eta=0.0):
    """Print ``label``, build a fresh GainModNet plus optimizer, and return its continual_train history."""
    print(label)
    # With gamma=0 and eta=0 every gain is reset to g0=1.0 on each step (a plain MLP).
    model = GainModNet(input_dim=28*28, hidden_dim=400, output_dim=10,
                       g0=1.0, gamma=gamma, eta=eta).to(device)
    optimizer = make_optimizer(_weight_only_param_groups(model, lr))
    return continual_train(
        model, optimizer, criterion,
        contexts, batch, ctx_iter, rho_eval,
        train_ds, test_ds, device,
        mode=mode, reset_optimizer=reset_optimizer,
        replay_buffer_size=1000, replay_fraction=0.5
    )


def run_all_models_nSims(nSims, base_seed=0):
    """
    Run every optimizer variant for ``nSims`` seeds and collect their histories.

    Seeds ``base_seed, base_seed + 1, ...`` reseed torch, NumPy and ``random``
    once per simulation. The six variants then run in a fixed order, each on a
    freshly built model, so the random streams match run for run.

    Returns:
        list of ``{"seed": int, "results": {variant name: continual_train history}}``.
    """
    def sgd(momentum):
        return lambda groups: optim.SGD(groups, momentum=momentum)

    all_results = []

    for sim in range(nSims):
        seed = base_seed + sim
        # --- Seeds ---
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        print(f"\n=== Sim {sim+1}/{nSims}, seed={seed} ===")

        # 1. ENTROPY GAIN MODEL (NGM-SGD): gains driven by output entropy.
        hgain_hist = _train_variant("ENTROPY GAIN MODEL", sgd(0.0), lr_weights,
                                    mode="gain-H", reset_optimizer=False,
                                    gamma=0.9, eta=0.5)
        # 2. MSGD (Standard)
        msgd_hist = _train_variant("MOMENTUM SGD", sgd(0.9), 10*lr_weights,
                                   mode="no-gain", reset_optimizer=False)
        # 3. MSGD RESET (Oracle)
        msgd_reset_hist = _train_variant("MOMENTUM SGD RESET", sgd(0.9), 10*lr_weights,
                                         mode="no-gain", reset_optimizer=True)
        # 4. ADAM (Standard)
        adam_hist = _train_variant("VANILLA ADAM", optim.Adam, 0.1*lr_weights,
                                   mode="no-gain", reset_optimizer=False)
        # 5. ADAM RESET (Oracle)
        adam_reset_hist = _train_variant("VANILLA ADAM RESET", optim.Adam, 0.1*lr_weights,
                                         mode="no-gain", reset_optimizer=True)
        # 6. SGD (Vanilla)
        sgd_hist = _train_variant("SGD", sgd(0.0), 10*lr_weights,
                                  mode="no-gain", reset_optimizer=False)

        # Storage (key order kept for downstream consumers of the pickle).
        results = {
            "ENTROPY GAIN": hgain_hist,
            "ADAM":         adam_hist,
            "ADAM reset":   adam_reset_hist,
            "MSGD":         msgd_hist,
            "MSGD reset":   msgd_reset_hist,
            "SGD":          sgd_hist,
        }
        all_results.append({"seed": seed, "results": results})

    return all_results

## Run
all_sims = run_all_models_nSims(nSims=5, base_seed=42)

# Save: one pickle holding, for every seed, the histories of all variants.
_results_path = "ER_ex4_rotatedMNIST_1K.pkl"
with open(_results_path, "wb") as f:
    pickle.dump(all_sims, f)
print(f"Done! Saved to {_results_path}")
