import torch
import logging
from tqdm import tqdm
from time import sleep
import argparse 
from utils import save_config, save_model2, get_logger
import os
from torch.cuda.amp.autocast_mode import autocast
from typing import List, Union, Literal
from contextlib import contextmanager
import random
from torch.func import jvp, vmap
import timeit
from jacobian import JacobianReg
import torchmetrics
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1Score, MeanMetric
import math


def myjacobian(y: torch.Tensor, x: torch.Tensor, need_higher_grad=True) -> torch.Tensor:
    """Return the full Jacobian ``d y / d x`` laid out as ``y.shape + x.shape``.

    All elements of ``y`` are back-propagated in one call by feeding the rows of
    an identity matrix as batched ``grad_outputs`` (``is_grads_batched=True``).

    Args:
        y: Output tensor that is part of an autograd graph.
        x: Input tensor (must require grad) to differentiate against.
        need_higher_grad: Forwarded to ``create_graph``.
            NOTE(review): the result is detached before returning, so the
            returned Jacobian is never differentiable regardless of this flag —
            confirm that is intended.

    Returns:
        Detached Jacobian with shape ``y.shape + x.shape``. When ``y`` does not
        depend on ``x`` an all-zero tensor on ``x``'s device/dtype is returned.
    """
    if x.is_cuda:
        # Put y on the same GPU as x (``.cuda()`` would pick the *current* device).
        y = y.to(x.device)

    # One identity row per element of y; dtype/device must match y.
    seeds = torch.eye(torch.numel(y), dtype=y.dtype, device=y.device)
    (jac,) = torch.autograd.grad(
        outputs=(y.flatten(),),
        inputs=(x,),
        grad_outputs=(seeds,),
        create_graph=need_higher_grad,
        allow_unused=True,
        is_grads_batched=True,
    )
    if jac is None:
        # y is independent of x: the Jacobian is identically zero.
        return torch.zeros(y.shape + x.shape, dtype=x.dtype, device=x.device)

    # Batched grads come back as [numel(y), *x.shape]; restore y's own axes so
    # both branches share the same layout.
    return jac.detach().reshape(y.shape + x.shape)

def batched_jacobian(batched_y: torch.Tensor, batched_x: torch.Tensor, need_higher_grad=True) -> torch.Tensor:
    """Return per-sample Jacobians of a batched computation, batch axis first.

    ``batched_y`` is summed over dim 0 before differentiating, so a single
    ``myjacobian`` call yields every sample's Jacobian. This is only exact when
    output sample ``b`` depends on input sample ``b`` alone; otherwise
    cross-sample derivatives are summed into the result.

    Args:
        batched_y: Output tensor whose dim 0 is the batch axis.
        batched_x: Input tensor whose dim 0 is the batch axis.
        need_higher_grad: Forwarded to ``myjacobian``.

    Returns:
        ``J`` with ``J[b, *i, *j] = d sum_b' batched_y[b', *i] / d batched_x[b, *j]``,
        assuming ``myjacobian`` returns ``summed_y.shape + batched_x.shape`` —
        i.e. shape ``(B,) + batched_y.shape[1:] + batched_x.shape[1:]``.
    """
    summed_y = batched_y.sum(dim=0)
    J = myjacobian(summed_y, batched_x, need_higher_grad)

    # x's batch axis sits right after y's axes. movedim() brings it to the front
    # while keeping all other axes in order; swapping it with axis 0 (the old
    # approach) reversed y's axes whenever y had more than one non-batch dim.
    return J.movedim(summed_y.dim(), 0)

def dict2str(dic):
    """Format a mapping as tab-separated ``key=value`` pairs, values to 4 decimals."""
    fields = []
    for name, val in dic.items():
        fields.append(f"{name}={val:.4f}")
    return "\t".join(fields)

@contextmanager
def eval_context(net: torch.nn.Module):
    """Temporarily switch to evaluation mode.

    Yields ``net``. If it was in training mode it is put in eval mode for the
    duration of the block and switched back to training mode on exit (also when
    the block raises). A module already in eval mode is left untouched.
    """
    if not net.training:
        # Nothing to toggle or restore.
        yield net
        return
    try:
        net.eval()
        yield net
    finally:
        net.train()

@contextmanager
def train_context(net: torch.nn.Module):
    """Temporarily switch to training mode.

    Yields ``net``. If it was in eval mode it is put in training mode for the
    duration of the block and switched back to eval mode on exit (also when the
    block raises). A module already in training mode is left untouched.
    """
    if net.training:
        # Nothing to toggle or restore.
        yield net
        return
    try:
        net.train()
        yield net
    finally:
        net.eval()


class HybridAT():
    """Trainer for an encoder / projector / classifier model.

    ``training_step`` combines:
      * a contrastive loss on two augmented views (``args.method`` ``'scl'`` or
        ``'hybrid'``; ``train_loss_fn`` returns ``(loss, info_dict)``),
      * after ``args.epoch_t`` epochs, cross-entropy on adversarial examples built
        by ``perturb`` for a random ``adv_rate`` fraction of the batch, plus
        ``adv_rate`` times the natural cross-entropy,
      * optionally (``args.jac_reg``) a Jacobian penalty weighted by ``args.lambda_JR``.

    This is a plain Python class (not an ``nn.Module``); ``train`` and ``test``
    implement its own training / evaluation loops.
    """

    def __init__(self, encoder, projector, classifier, train_loss_fn, val_loss_fn, adv_rate, args=None, device=None, test_mode=False):
        """Store models, losses and hyper-parameters; set up loggers and metrics.

        Args:
            encoder, projector, classifier: Modules, moved to ``device``.
            train_loss_fn: Contrastive loss returning ``(loss, info_dict)``.
            val_loss_fn: Loss used by ``validation_step`` and ``test_step``.
            adv_rate: Fraction of each batch turned into adversarial examples;
                also weights the natural cross-entropy in the surrogate loss.
            args: Run configuration (``num_classes``, ``epochs``, ``nb_iters``,
                ``eps_iter``, ``epoch_t``, ``save_folder`` / ``result_path`` ...).
                Required despite the ``None`` default.
            device: Device for the models and metrics.
            test_mode: Fallback used only when ``args`` has no ``test_mode``
                attribute; ``args.test_mode`` takes precedence.
        """
        super().__init__()

        self.args = args
        self.save_hyperparameters()
        self.num_classes = self.args.num_classes
        self.epochs = self.args.epochs
        self.train_loss_fn = train_loss_fn
        self.val_loss_fn = val_loss_fn
        self.adv_rate = adv_rate

        self.nb_iters = self.args.nb_iters
        self.eps_iter = self.args.eps_iter
        self.ce_loss_fn = torch.nn.CrossEntropyLoss()
        self.loss_scale = 1.0  # self.args.loss_scale
        self.epoch_t = self.args.epoch_t
        # args.test_mode wins (previous behaviour); the keyword is only a fallback.
        self.test_mode = getattr(self.args, "test_mode", test_mode)

        if not self.test_mode:
            self.mylogger = get_logger(logpath=f"{self.args.save_folder}/logs.log", displaying=False)
            print(f"logging file: {args.save_folder}/logs.log")
            self.train_logger = get_logger(logpath=f"{self.args.save_folder}/train.log", displaying=False)
            print(f"training logging file: {args.save_folder}/train.log")
        else:
            print(f"logging file: {args.result_path}")
            self.mylogger = get_logger(logpath=args.result_path, displaying=False)

        self.device = device

        self.encoder = encoder.to(self.device)
        self.projector = projector.to(self.device)
        self.classifier = classifier.to(self.device)

        def top1_accuracy():
            return torchmetrics.Accuracy(
                task="multiclass", num_classes=self.num_classes, top_k=1
            ).to(self.device)

        # Stand-alone top-1 accuracy meters (not used inside this class).
        self.train_natural_acc = top1_accuracy()
        self.train_robust_acc = top1_accuracy()
        self.val_natural_acc = top1_accuracy()
        self.val_robust_acc = top1_accuracy()

        # Epoch-level metrics. clone(prefix=...) makes compute() return keys such
        # as 'train_loss' or 'val_natural_accuracy'.
        metrics = MetricCollection({
            'natural_accuracy': Accuracy(task="multiclass", num_classes=self.num_classes),
            'robust_accuracy': Accuracy(task="multiclass", num_classes=self.num_classes),
            'loss': MeanMetric(),
            'loss_cl': MeanMetric(),
            'loss_dr': MeanMetric(),
            'loss_jac': MeanMetric()
        })

        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

        self.train_metrics.to(device)
        self.val_metrics.to(device)
        self.test_metrics.to(device)

    # ------------------------------------------------------------------ #
    # Forward helpers
    # ------------------------------------------------------------------ #
    @torch.enable_grad()
    @torch.inference_mode(False)
    def f(self, x):
        """Return ``projector(encoder(x))``.

        Grad mode is forced on (and inference mode off) so the result stays
        differentiable even when called from a no-grad / inference context.
        """
        return self.projector(self.encoder(x))

    def f_stopBN(self, x):
        """Evaluate ``f(x)`` with BatchNorm layers of encoder/projector in eval mode.

        Each BatchNorm layer's previous train/eval mode is restored afterwards,
        also if the forward pass raises.
        """
        with self._frozen_batchnorm(self.encoder, self.projector):
            return self.f(x)

    def jvp_fn(self, x, tangent):
        """Return the Jacobian-vector product ``J_f(x) @ tangent``.

        BatchNorm layers of encoder/projector are put in eval mode for the call
        (so running statistics are used rather than batch statistics and are not
        updated) and each layer's previous mode is restored afterwards.
        """
        with self._frozen_batchnorm(self.encoder, self.projector):
            _, jv = jvp(self.f, (x,), (tangent,))
        return jv

    def run_jvp(self, x, S):
        """Apply ``J_f(x)`` to every tangent stacked in ``S``.

        Args:
            x: Input batch.
            S: Tangents stacked on a new leading axis, ``[num_vectors, *x.shape]``.

        Returns:
            ``J·S``: one JVP per tangent, stacked on the leading axis.
        """
        def single_jvp(tangent):
            return self.jvp_fn(x, tangent)

        # CUDA autocast is (re-)enabled here even when the caller disabled it.
        with torch.autocast(device_type="cuda"):
            Jv = vmap(single_jvp)(S)

        return Jv

    # ------------------------------------------------------------------ #
    # Jacobian regularisation
    # ------------------------------------------------------------------ #
    def jac_reg_r(self, batch_x, num_proj=1):
        """Random-projection Jacobian penalty ``0.5 * C * mean_b sum_p ||J_b v_p||^2``.

        ``num_proj`` Gaussian directions per sample are scaled to unit L2 norm
        (per sample) and pushed through ``run_jvp``. With
        ``C = C*H*W / num_proj`` the result is a random-projection estimate of
        half the squared Frobenius norm of the per-sample Jacobian, averaged over
        the batch.

        Args:
            batch_x: Input batch of shape ``(B, C, H, W)`` (enforced by unpacking).
            num_proj: Number of random projections per sample.

        Returns:
            Scalar tensor ``R``.
        """
        batch_size, channels, height, width = batch_x.shape
        with torch.no_grad():
            S_proj = torch.randn(num_proj, batch_size, channels, height, width,
                                 device=batch_x.device, dtype=torch.float32)
            # Normalise each (projection, sample) direction over its C*H*W entries.
            flat_vectors = S_proj.view(num_proj, batch_size, -1)
            norms = torch.norm(flat_vectors, p=2, dim=2, keepdim=True)
            flat_vectors = flat_vectors / (norms + 1e-8)

            check_norms = torch.norm(flat_vectors, p=2, dim=2)
            assert torch.allclose(check_norms, torch.ones_like(check_norms), rtol=1e-3, atol=1e-3), \
                "Vectors are not properly normalized to unit length"

            S_proj = flat_vectors.view_as(S_proj)

        with torch.autocast(device_type="cuda", enabled=False):
            # run_jvp re-enables CUDA autocast internally, so jv may come back in
            # reduced precision; cast to float32 before squaring and summing.
            jv = self.run_jvp(batch_x, S_proj).float()
            jv = jv.permute(1, 0, 2)  # -> [batch, num_proj, feat]
            frob_norm_squared = torch.sum(jv.reshape(batch_size, -1) ** 2, dim=1)
            C = torch.tensor(channels * height * width / num_proj,
                             dtype=torch.float32,
                             device=batch_x.device)
            J2 = C * frob_norm_squared.mean()
            R = 0.5 * J2
        return R

    @contextmanager
    def _frozen_batchnorm(self, *models):
        """Put every BatchNorm layer of ``models`` in eval mode; restore each layer's previous mode on exit."""
        bn_layers = [
            module
            for model in models
            for module in model.modules()
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
        ]
        previous_modes = [layer.training for layer in bn_layers]
        try:
            for layer in bn_layers:
                layer.eval()
            yield
        finally:
            for layer, was_training in zip(bn_layers, previous_modes):
                layer.train(was_training)

    def disable_batchnorm_running(self, model):
        """Switch every BatchNorm layer in ``model`` to eval mode."""
        for module in model.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                module.eval()

    def enable_batchnorm_running(self, model):
        """Switch every BatchNorm layer in ``model`` to training mode."""
        for module in model.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                module.train()

    def jac_reg_sel(self, in_x, out_x, sel = 1):
        """Compute the Jacobian regulariser selected by ``sel``.

        Args:
            in_x: Input batch (must require grad for ``sel == 1``).
            out_x: Output computed from ``in_x``; only used by ``sel == 1``.
            sel: ``1`` -> ``JacobianReg()(in_x, out_x)``; ``2`` -> ``jac_reg_r(in_x)``.

        Raises:
            ValueError: For any other ``sel``.
        """
        if sel == 1:
            return JacobianReg()(in_x, out_x)
        if sel == 2:
            return self.jac_reg_r(in_x)
        raise ValueError(f"Unknown Jacobian regulariser selector: {sel}")

    # ------------------------------------------------------------------ #
    # Adversarial example generation
    # ------------------------------------------------------------------ #
    def compute_top_k_singular_vectors(self, batch_x, num_proj=10, k=5):
        """Per-sample input-space directions from a randomised SVD of ``J_f``.

        ``num_proj`` Gaussian tangents are pushed through ``run_jvp``. Each
        sample's ``[num_proj, feat]`` JVP matrix is factorised, and its top-``k``
        left singular vectors recombine the tangents into ``k`` input-space
        directions.

        Args:
            batch_x: Input batch of shape ``(B, C, H, W)``.
            num_proj: Number of random tangents (previously always forced to 10,
                which is still the default).
            k: Number of directions to return.

        Returns:
            Tensor of shape ``(B, k, C, H, W)`` (directions are not normalised).

        Raises:
            ValueError: If ``k`` exceeds ``min(num_proj, feat)``.
        """
        batch_size, channels, height, width = batch_x.shape
        S_proj = torch.randn(num_proj, batch_size, channels, height, width, device=batch_x.device)

        jv = self.run_jvp(batch_x, S_proj)
        jv = jv.permute(1, 0, 2)  # [batch, num_proj, feat]; assumes f returns (batch, feat)

        U = torch.linalg.svd(jv.float(), full_matrices=False).U
        if k > U.shape[-1]:
            raise ValueError(
                f"k={k} exceeds the available rank min(num_proj, feat)={U.shape[-1]}"
            )
        U_k = U[:, :, :k]  # [batch, num_proj, k]

        flat_proj = S_proj.permute(1, 0, 2, 3, 4).reshape(batch_size, num_proj, -1)
        directions = torch.bmm(U_k.transpose(1, 2), flat_proj)  # [batch, k, C*H*W]
        return directions.view(batch_size, k, channels, height, width)

    def perturb(self, target: torch.Tensor, batch_x_start: torch.Tensor):
        """Build adversarial examples for ``batch_x_start`` with ``nb_iters`` ``_iter`` steps.

        Tangent directions come from ``compute_top_k_singular_vectors`` and are
        recomputed every ``ceil(nb_iters / args.nb_tans)`` iterations. Model
        parameters are frozen for the duration; each parameter's original
        ``requires_grad`` flag is restored afterwards, also on error.

        Raises:
            ValueError: If ``args.nb_tans`` is not a positive integer.
        """
        # Validate before touching any parameter flags.
        if not (isinstance(self.args.nb_tans, int) and self.args.nb_tans > 0):
            raise ValueError(
                f"nb_tans must be a positive integer greater than 0, current value: {self.args.nb_tans}"
            )
        tan_interval = math.ceil(self.nb_iters / self.args.nb_tans)

        saved_flags = [
            (p, p.requires_grad)
            for module in (self.encoder, self.projector, self.classifier)
            for p in module.parameters()
        ]
        for p, _ in saved_flags:
            p.requires_grad = False

        try:
            batch_x = batch_x_start.detach().clone().requires_grad_(True)
            tangent_a = None
            for idx_iter in range(self.nb_iters):
                if idx_iter % tan_interval == 0:
                    tangent_a = self.compute_top_k_singular_vectors(batch_x)
                batch_x = self._iter(target, batch_x, tangent_a).requires_grad_(True)
        finally:
            for p, flag in saved_flags:
                p.requires_grad = flag

        return batch_x

    def _iter(self, target=None, batch_x=None, tangent_a=None):
        """One attack step restricted to the given tangent directions.

        The input gradient of the cross-entropy is projected onto each
        unit-normalised direction in ``tangent_a`` (``(B, k, C, H, W)``), the
        projections are summed, and ``eps_iter`` times that sum is added to
        ``batch_x``. The result is clamped to ``[0, 1]`` (no epsilon-ball
        projection) and returned detached.
        """
        batch_size, k = tangent_a.shape[:2]
        norms = tangent_a.reshape(batch_size * k, -1).norm(dim=1) + 1e-8
        directions = tangent_a / norms.view(batch_size, k, 1, 1, 1)

        batch_x.grad = None
        loss = self.ce_loss_fn(self.classifier(self.encoder(batch_x)), target)
        loss.backward()
        grad = batch_x.grad

        # Component of the gradient along each of the k directions.
        coeffs = torch.einsum('bchw,bkchw->bk', grad, directions)
        step = (coeffs[:, :, None, None, None] * directions).sum(dim=1)

        batch_x_new = torch.clamp(batch_x + self.eps_iter * step, 0, 1)
        return batch_x_new.detach()

    # ------------------------------------------------------------------ #
    # Steps
    # ------------------------------------------------------------------ #
    def training_step(self, epoch, batch):
        """Compute the training loss for one batch and update ``train_metrics``.

        Args:
            epoch: 1-based epoch; adversarial training starts once ``epoch > epoch_t``.
            batch: ``(imgs, target)`` where ``imgs`` is a list/tuple of three
                views: ``imgs[0]`` (classification/adversarial) and ``imgs[1]``,
                ``imgs[2]`` (contrastive).

        Returns:
            ``(loss, instant_natural_acc, instant_robust_acc, Jac_reg)``.
            ``instant_robust_acc`` is ``0.0`` when no adversarial examples were
            used; ``Jac_reg`` is a float, ``0.0`` when ``args.jac_reg`` is off.

        Raises:
            ValueError: If ``args.method`` is neither ``'scl'`` nor ``'hybrid'``.
        """
        imgs, target = batch
        assert isinstance(imgs, (list, tuple)) and len(imgs) == 3
        img0, img1, img2 = imgs[0].to(self.device), imgs[1].to(self.device), imgs[2].to(self.device)
        target = target.to(self.device)
        N = img0.size(0)

        # Input gradients of the contrastive views are only needed for the Jacobian penalty.
        needs_input_grad = bool(self.args.jac_reg)
        img1.requires_grad = needs_input_grad
        img2.requires_grad = needs_input_grad

        feature1, feature2 = self.encoder(img1), self.encoder(img2)
        projection1, projection2 = self.projector(feature1), self.projector(feature2)

        if self.args.method == 'scl':
            loss_cl, loss_cl_dict = self.train_loss_fn(projection1, projection2, target)
        elif self.args.method == 'hybrid':
            pred1, pred2 = self.classifier(feature1), self.classifier(feature2)
            loss_cl, loss_cl_dict = self.train_loss_fn(projection1, projection2, pred1, pred2, target)
        else:
            raise ValueError(f"training_step does not support method: {self.args.method}")

        feature0 = self.encoder(img0)
        pred0 = self.classifier(feature0)
        loss_natural = self.ce_loss_fn(pred0, target)

        loss_surrogate = 0.0
        prediction_adv = target_ = None
        adv_nsamples = int(N * self.adv_rate) if epoch > self.epoch_t else 0
        # An empty adversarial subset would give a NaN cross-entropy; skip it.
        use_adv = adv_nsamples > 0
        if use_adv:
            index = torch.tensor(random.sample(range(N), adv_nsamples), dtype=torch.long, device=self.device)
            img0_ = torch.index_select(img0, 0, index)
            target_ = torch.index_select(target, 0, index)

            img0_adv = self.perturb(target_, img0_)
            prediction_adv = self.classifier(self.encoder(img0_adv))

            loss_robust = self.ce_loss_fn(prediction_adv, target_)
            loss_surrogate = loss_robust + self.adv_rate * loss_natural

        loss = loss_cl + loss_surrogate

        Jac_reg = 0.0  # reported value when the Jacobian penalty is disabled
        jac_term = None
        if self.args.jac_reg:
            R1 = self.jac_reg_sel(img1, projection1, self.args.jac_sel)
            R2 = self.jac_reg_sel(img2, projection2, self.args.jac_sel)
            jac_term = self.args.lambda_JR * (R1 + R2)
            loss = loss + jac_term
            del R1, R2

        with torch.no_grad():
            self.train_metrics['loss'](loss.item())
            self.train_metrics['loss_cl'](loss_cl.item())
            natural_pred = pred0.argmax(dim=1)
            self.train_metrics['natural_accuracy'](natural_pred, target)
            instant_natural_acc = (natural_pred == target).float().mean()

            if use_adv:
                robust_pred = prediction_adv.argmax(dim=1)
                instant_robust_acc = (robust_pred == target_).float().mean()
                self.train_metrics['robust_accuracy'](robust_pred, target_)
                self.train_metrics['loss_dr'](loss_surrogate.item())
            else:
                instant_robust_acc = torch.tensor(0.0)

            if jac_term is not None:
                Jac_reg = jac_term.item()
                self.train_metrics['loss_jac'](Jac_reg)

        return loss, instant_natural_acc, instant_robust_acc, Jac_reg

    def validation_step(self, epoch, batch):
        """Evaluate one validation batch without gradients and update ``val_metrics``.

        Returns:
            ``(loss, instant_natural_acc)`` for the batch.

        NOTE(review): class-index targets are passed to ``val_loss_fn`` here,
        while ``test_step`` passes one-hot targets — confirm which form
        ``val_loss_fn`` expects.
        """
        imgs, target = batch
        imgs = imgs.to(self.device)
        target = target.to(self.device)

        with torch.no_grad():
            features = self.encoder(imgs)
            predictions = self.classifier(features)
            loss = self.val_loss_fn(predictions, target)

            self.val_metrics['loss'](loss.item())
            self.val_metrics['natural_accuracy'](predictions.argmax(dim=1), target)

            instant_natural_acc = (predictions.argmax(dim=1) == target).float().mean()

        return loss, instant_natural_acc

    # ------------------------------------------------------------------ #
    # Loops
    # ------------------------------------------------------------------ #
    def _set_modules_training(self, mode: bool):
        """Put encoder, projector and classifier together into train (True) or eval (False) mode."""
        for module in (self.encoder, self.projector, self.classifier):
            module.train(mode)

    def train(self, train_dataloader, val_dataloader):
        """Train for ``self.epochs`` epochs with validation and checkpointing.

        Uses Adam (``args.lr``, ``args.wd``) with cosine-annealed learning rate.
        After each epoch, ``best_model.pth`` is saved when validation natural
        accuracy improves; otherwise ``epoch_XXXX_model.pth`` is saved every
        ``args.save_freq`` epochs. ``last_model.pth`` is always written at the end.

        Raises:
            ValueError: If ``args.method`` is not a supported method.
        """
        save_config(self.args, self.args.save_folder)
        self.mylogger.info(self.encoder)
        self.mylogger.info(self.projector)
        self.mylogger.info(self.classifier)
        self.mylogger.info(self.args)

        best_val_acc = float('-inf')
        best_epoch = 0

        if self.args.method not in ['cl', 'scl', 'hybrid', 'sl', 'ae']:
            raise ValueError(f"Unknown method: {self.args.method}")
        parameters = (
            list(self.encoder.parameters())
            + list(self.projector.parameters())
            + list(self.classifier.parameters())
        )

        optimizer = torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.wd)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.epochs, eta_min=1e-6
        )

        def save_checkpoint(epoch_idx, filename):
            save_model2(
                {"encoder": self.encoder, "projector": self.projector, "classifier": self.classifier},
                optimizer, self.args, epoch_idx,
                os.path.join(self.args.save_folder, filename),
            )

        epoch = 0  # keeps the final checkpoint valid when self.epochs == 0
        for epoch in range(1, self.epochs + 1):
            self.train_metrics.reset()
            self.val_metrics.reset()

            # ---- training phase: all three modules in train mode ----
            self._set_modules_training(True)
            pbar = tqdm(train_dataloader, total=len(train_dataloader), leave=True, dynamic_ncols=True)
            for batch_idx, batch in enumerate(pbar):
                pbar.set_description(f"Train {self.args.method}: epoch [{epoch}/{self.epochs}] [{batch_idx}/{len(train_dataloader)}]")

                optimizer.zero_grad()
                train_loss, train_natural_acc, train_robust_acc, Jac_reg = self.training_step(epoch, batch)
                train_loss.backward()
                optimizer.step()

                # No pbar.update(): iterating a tqdm already advances it.
                pbar.set_postfix({"train_loss": train_loss.item(),
                                  "train_natural_acc": train_natural_acc.item(),
                                  "train_robust_acc": train_robust_acc.item(),
                                  "Jac_reg": Jac_reg})

            lr_scheduler.step()
            train_metrics = self.train_metrics.compute()  # epoch averages

            # ---- validation phase: all three modules in eval mode ----
            self._set_modules_training(False)
            total_val_samples = len(val_dataloader.dataset)
            batch_size = val_dataloader.batch_size
            num_batches = len(val_dataloader)
            self.mylogger.info(f"Validation set size: {total_val_samples}, "
                               f"batch size: {batch_size}, "
                               f"total batches: {num_batches}")
            pbar = tqdm(val_dataloader, total=len(val_dataloader), leave=True, dynamic_ncols=True)
            for batch_idx, batch in enumerate(pbar):
                pbar.set_description(f"Val: epoch [{epoch}/{self.epochs}] [{batch_idx}/{len(val_dataloader)}]")

                val_loss, val_natural_acc = self.validation_step(epoch, batch)

                pbar.set_postfix({
                    "val_loss": f"{val_loss.item():.4f}",
                    "val_acc": f"{val_natural_acc:.4f}"
                })

            val_metrics = self.val_metrics.compute()
            self.train_logger.info(
                f"Epoch {epoch} Train metrics: {dict2str(train_metrics)} | "
                f"Detailed validation metrics: {dict2str(val_metrics)}"
            )
            self.mylogger.info(
                f"Epoch {epoch} |\n"
                f"Train: {dict2str(train_metrics)} |\n"
                f"Val: {dict2str(val_metrics)} |\n"
                f"LR: {lr_scheduler.get_last_lr()[0]:.6f}"
            )

            val_acc = float(val_metrics['val_natural_accuracy'])
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                save_checkpoint(epoch, 'best_model.pth')
            elif epoch % self.args.save_freq == 0:
                save_checkpoint(epoch, f'epoch_{epoch:04d}_model.pth')

        save_checkpoint(epoch, 'last_model.pth')

        self.mylogger.info(f"Training finished. Best acc: {format(best_val_acc, '.4f')} at epoch {best_epoch}")

    def test(self, test_dataloader):
        """Evaluate on ``test_dataloader`` and log average and detailed test metrics.

        Encoder and classifier are put in eval mode for the loop, and their
        previous modes are restored afterwards.
        """
        self.test_metrics.reset()

        pbar = tqdm(test_dataloader, total=len(test_dataloader), leave=True, dynamic_ncols=True)

        total_test_samples = len(test_dataloader.dataset)
        batch_size = test_dataloader.batch_size
        num_batches = len(test_dataloader)
        self.mylogger.info(f"Test set size: {total_test_samples}, "
                           f"batch size: {batch_size}, "
                           f"total batches: {num_batches}")

        with eval_context(self.encoder), eval_context(self.classifier):
            for batch_idx, (imgs, target) in enumerate(pbar):
                pbar.set_description(f"Test: [{batch_idx}/{len(test_dataloader)}]")

                test_loss, test_natural_acc = self.test_step(imgs, target)

                # The last batch may be partial; never report more than the dataset size.
                seen = min((batch_idx + 1) * batch_size, total_test_samples)
                pbar.set_postfix({
                    "batch": f"{batch_idx}/{num_batches}",
                    "samples": f"{seen}/{total_test_samples}",
                    "test_loss": f"{test_loss.item():.4f}",
                    "test_acc": f"{test_natural_acc:.4f}"
                })

        test_metrics = self.test_metrics.compute()
        test_avg = {
            "test_avg_loss": test_metrics['test_loss'],
            "test_avg_natural_acc": test_metrics['test_natural_accuracy']
        }

        self.mylogger.info(f"Test average metrics: {dict2str(test_avg)}")
        self.mylogger.info(f"Detailed test metrics: {dict2str(test_metrics)}")

    def test_step(self, imgs, target):
        """Evaluate one test batch without gradients and update ``test_metrics``.

        ``val_loss_fn`` receives one-hot float targets here.

        Returns:
            ``(loss, instant_natural_acc)`` for the batch.
        """
        imgs = imgs.to(self.device)
        target = target.to(self.device)

        with torch.no_grad():
            features = self.encoder(imgs)
            predictions = self.classifier(features)
            target_ = torch.nn.functional.one_hot(target, num_classes=self.num_classes).float().to(self.device)
            loss = self.val_loss_fn(predictions, target_)

            # Un-prefixed keys, as in training/validation; the 'test_' prefix is
            # only added to the keys returned by compute().
            self.test_metrics['loss'](loss.item())
            self.test_metrics['natural_accuracy'](predictions.argmax(dim=1), target)

            val_natural_acc = (predictions.argmax(dim=1) == target).float().mean()

        return loss, val_natural_acc

    def save_hyperparameters(self):
        """No-op placeholder (called from ``__init__``)."""
        pass