// Package errkit implements all errors generated by nuclei and includes error definitions
// specific to nuclei, error classification (like network, logic), etc.
package errkit

import (
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"runtime"
	"strconv"
	"strings"
	"time"

	"github.com/projectdiscovery/utils/env"
)

const (
	// DelimArrow is delim used by projectdiscovery/utils to join errors
	DelimArrow = "<-"
	// DelimArrowSerialized is DelimArrow as it appears after JSON encoding:
	// encoding/json escapes '<' as \u003c, so a serialized chain contains the
	// six literal characters `\u003c` followed by '-'. It is a raw string on
	// purpose: the interpreted literal "\u003c-" decodes to "<-" and would be
	// indistinguishable from DelimArrow.
	DelimArrowSerialized = `\u003c-`
	// DelimSemiColon is standard delim popularly used to join errors
	DelimSemiColon = "; "
	// DelimMultiLine is delim used to join errors in multiline format
	DelimMultiLine = "\n -  "
	// MultiLineErrPrefix is the prefix used for multiline errors
	MultiLineErrPrefix = "the following errors occurred:"
	// Space is the separator used between fields of a rendered error
	Space = " "
)

var (
	// MaxErrorDepth is the maximum number of errors parseError collects into
	// a single ErrorX; once that many are held, further errors encountered
	// while parsing are ignored. Read via env.GetEnvOrDefault from
	// MAX_ERROR_DEPTH (default 3).
	// NOTE(review): only parseError checks this limit; append (used by Msg,
	// Msgf and when parseError merges another *ErrorX) can grow past it.
	MaxErrorDepth = env.GetEnvOrDefault("MAX_ERROR_DEPTH", 3)
	// ErrFieldSeparator is the configured separator between error fields,
	// read via env.GetEnvOrDefault from ERR_FIELD_SEPERATOR (default Space).
	// NOTE(review): not referenced in this part of the file — presumably
	// used elsewhere in the package; verify before relying on it.
	ErrFieldSeparator = env.GetEnvOrDefault("ERR_FIELD_SEPERATOR", Space)
	// ErrChainSeperator joins the non-cause entries rendered in the
	// chain="..." field of ErrorX.Error. Read via env.GetEnvOrDefault from
	// ERR_CHAIN_SEPERATOR (default DelimSemiColon). The spelling of both the
	// identifier and the variable name is kept as-is: renaming would break
	// callers and existing environment configuration.
	ErrChainSeperator = env.GetEnvOrDefault("ERR_CHAIN_SEPERATOR", DelimSemiColon)
	// EnableTimestamp controls whether error timestamps are included: when
	// true, ErrorX.init stamps the record with time.Now(). Read via
	// env.GetEnvOrDefault from ENABLE_ERR_TIMESTAMP (default false).
	EnableTimestamp = env.GetEnvOrDefault("ENABLE_ERR_TIMESTAMP", false)
	// EnableTrace controls whether the call site is captured: when true,
	// ErrorX.init records the calling function, file and line as the error's
	// source. Read via env.GetEnvOrDefault from ERRKIT_ENABLE_TRACE
	// (default false).
	EnableTrace = env.GetEnvOrDefault("ERRKIT_ENABLE_TRACE", false)
)

// ErrorX is a custom error type that can represent errors produced by any of
// the known wrapping and joining strategies (including custom ones). It also
// carries an error kind (class) that can be shown to clients/users in a more
// meaningful way.
//
// An ErrorX holds a flattened, message-de-duplicated list of errors whose
// first entry is reported as the cause, an optional kind, slog attributes
// and, when tracing is enabled, the source location where it was created.
type ErrorX struct {
	kind   ErrKind      // classification; nil is reported as unknown by Kind
	record *slog.Record // slog attributes (plus creation time with EnableTimestamp); allocated by init
	source *slog.Source // call site captured by init when EnableTrace is set
	errs   []error      // flattened chain; errs[0] is returned by Cause
}

// init lazily prepares e for use. It allocates the attribute record and,
// depending on EnableTimestamp and EnableTrace, stamps the creation time and
// captures a call-site location. It is a no-op once the record exists.
//
// skipStack optionally overrides the skip count handed to runtime.Callers.
// The default of 3 skips runtime.Callers, init itself and the function that
// called init, so the recorded source is that function's caller. When the
// stack is too shallow for the requested skip, no source is recorded (rather
// than an empty placeholder).
func (e *ErrorX) init(skipStack ...int) {
	if e.record != nil {
		return
	}
	e.record = &slog.Record{}
	if EnableTimestamp {
		e.record.Time = time.Now()
	}
	if !EnableTrace {
		return
	}
	// skip [runtime.Callers, ErrorX.init, parent]
	skip := 3
	if len(skipStack) > 0 {
		skip = skipStack[0]
	}
	var pcs [1]uintptr
	// runtime.Callers reports how many PCs it filled; 0 means there is no
	// frame at that depth and pcs[0] would be a meaningless zero.
	if runtime.Callers(skip, pcs[:]) == 0 {
		return
	}
	frame, _ := runtime.CallersFrames(pcs[:]).Next()
	e.source = &slog.Source{
		Function: frame.Function,
		File:     frame.File,
		Line:     frame.Line,
	}
}

// append adds errs to e.errs, skipping nil errors (they carry no information
// and would panic on Error()) and any error whose message equals one already
// collected. A linear scan over e.errs is used instead of a map to avoid the
// extra allocations a map costs for what are typically very short chains.
func (e *ErrorX) append(errs ...error) {
	for _, nerr := range errs {
		if nerr == nil {
			continue
		}
		msg := nerr.Error()
		duplicate := false
		for _, oerr := range e.errs {
			if oerr.Error() == msg {
				duplicate = true
				break
			}
		}
		if !duplicate {
			e.errs = append(e.errs, nerr)
		}
	}
}

// MarshalJSON encodes the error as a JSON object with the keys:
//
//   - "kind":   name of the error kind (resolved through Kind, so an unset
//     or empty kind is reported as ErrKindUnknown)
//   - "errors": every collected error message, in order (never null)
//   - "attrs":  the attached slog attributes as a JSON object (omitted when
//     there are none)
//   - "source": the captured call site (omitted when none was recorded)
//
// slog.Value has no exported fields and no JSON encoding of its own, so the
// attributes are first converted to plain maps; handing a slog.Value to
// encoding/json directly would always produce {}.
//
// The value receiver is kept so both ErrorX and *ErrorX marshal this way;
// Kind only updates the local copy.
func (e ErrorX) MarshalJSON() ([]byte, error) {
	messages := make([]string, 0, len(e.errs))
	for _, err := range e.errs {
		messages = append(messages, err.Error())
	}
	m := map[string]interface{}{
		"kind":   e.Kind().String(),
		"errors": messages,
	}
	if attrs := e.Attrs(); len(attrs) > 0 {
		m["attrs"] = slogAttrsToMap(attrs)
	}
	if e.source != nil {
		m["source"] = e.source
	}
	return json.Marshal(m)
}

// slogAttrsToMap converts slog attributes into values encoding/json can
// serialize. LogValuers are resolved. Groups become nested maps; a group with
// an empty key is inlined into its parent, following slog's handler
// convention. Error values are rendered with Error(). Values that
// encoding/json cannot encode (NaN/Inf floats, funcs, channels, ...) fall back
// to slog's textual form. A later duplicate key overwrites an earlier one.
func slogAttrsToMap(attrs []slog.Attr) map[string]interface{} {
	out := make(map[string]interface{}, len(attrs))
	for _, a := range attrs {
		v := a.Value.Resolve()
		if v.Kind() == slog.KindGroup {
			nested := slogAttrsToMap(v.Group())
			if a.Key != "" {
				out[a.Key] = nested
				continue
			}
			for k, nv := range nested {
				out[k] = nv
			}
			continue
		}
		val := v.Any()
		if err, ok := val.(error); ok {
			out[a.Key] = err.Error()
			continue
		}
		raw, mErr := json.Marshal(val)
		if mErr != nil {
			// keep the whole error marshalable even if one attribute is not
			out[a.Key] = v.String()
			continue
		}
		out[a.Key] = json.RawMessage(raw)
	}
	return out
}

// Errors returns the flattened list of errors held by e, starting with the
// cause. The internal slice is returned directly (no copy), so callers should
// treat it as read-only.
func (e *ErrorX) Errors() []error {
	collected := e.errs
	return collected
}

// Attrs collects the slog attributes attached to the error, in the order they
// were added. It returns nil when there is no record or the record holds no
// attributes.
func (e *ErrorX) Attrs() []slog.Attr {
	if e.record == nil {
		return nil
	}
	count := e.record.NumAttrs()
	if count == 0 {
		return nil
	}
	out := make([]slog.Attr, 0, count)
	e.record.Attrs(func(attr slog.Attr) bool {
		out = append(out, attr)
		return true
	})
	return out
}

// Build returns e as an error value. A nil *ErrorX yields a nil error rather
// than a non-nil interface wrapping a nil pointer, so a call such as
//
//	return errkit.FromError(err).Build()
//
// correctly returns nil when err is nil (FromError returns nil for nil).
func (e *ErrorX) Build() error {
	if e == nil {
		return nil
	}
	return e
}

// Unwrap exposes every collected error so that errors.Is and errors.As, which
// walk all errors returned by an Unwrap() []error method, can inspect them.
func (e *ErrorX) Unwrap() []error {
	wrapped := e.errs
	return wrapped
}

// Is reports whether e matches err. err is flattened with parseError (as
// FromError does), and e matches when errors.Is succeeds for at least one
// pair of an error held by e and an error extracted from err.
func (e *ErrorX) Is(err error) bool {
	// A bare record is all parseError needs to merge attributes into. Going
	// through init would also read the clock and walk the stack when
	// timestamps/tracing are enabled, and that work would be thrown away here
	// on every errors.Is call that reaches this method.
	x := &ErrorX{record: &slog.Record{}}
	parseError(x, err)
	// even one submatch is enough
	for _, orig := range e.errs {
		for _, match := range x.errs {
			if errors.Is(orig, match) {
				return true
			}
		}
	}
	return false
}

// Error renders the error as space-separated key=value fields:
//
//	cause="<first error>" [<attr>=<value> ...] [chain="<other errors>"]
//
// The cause and chain values are Go-quoted. Chain entries are whitespace-
// trimmed and joined with ErrChainSeperator. An ErrorX holding no errors
// renders as cause="" instead of panicking on an empty slice.
func (e *ErrorX) Error() string {
	var sb strings.Builder
	cause := ""
	if len(e.errs) > 0 {
		cause = e.errs[0].Error()
	}
	sb.WriteString("cause=")
	sb.WriteString(strconv.Quote(cause))
	if e.record != nil && e.record.NumAttrs() > 0 {
		fields := make([]string, 0, e.record.NumAttrs())
		e.record.Attrs(func(a slog.Attr) bool {
			fields = append(fields, a.String())
			return true
		})
		sb.WriteString(Space)
		sb.WriteString(strings.Join(fields, " "))
	}
	if len(e.errs) > 1 {
		chain := make([]string, 0, len(e.errs)-1)
		for _, value := range e.errs[1:] {
			chain = append(chain, strings.TrimSpace(value.Error()))
		}
		sb.WriteString(Space)
		sb.WriteString("chain=")
		sb.WriteString(strconv.Quote(strings.Join(chain, ErrChainSeperator)))
	}
	return sb.String()
}

// Cause returns the first collected error, i.e. the original error that
// caused this one, without any wrapping. It returns nil when e holds no
// errors.
func (e *ErrorX) Cause() error {
	if len(e.errs) == 0 {
		return nil
	}
	return e.errs[0]
}

// Kind returns the kind associated with this error. When no kind is set, or
// the kind's String() is empty, ErrKindUnknown is stored on e and returned.
func (e *ErrorX) Kind() ErrKind {
	if e.kind != nil && e.kind.String() != "" {
		return e.kind
	}
	e.kind = ErrKindUnknown
	return e.kind
}

// FromError converts an arbitrary error into an *ErrorX. err is flattened by
// parseError, so nested *ErrorX values and joined, wrapped, cause-carrying
// and delimiter-separated errors are broken down into individual entries,
// and the error kind is taken from any nested *ErrorX. It returns nil for a
// nil error.
func FromError(err error) *ErrorX {
	if err == nil {
		return nil
	}
	converted := &ErrorX{}
	// called directly from here so that, with tracing enabled, the captured
	// source is FromError's caller
	converted.init()
	parseError(converted, err)
	return converted
}

// New creates an error with message msg. Any args are attached as slog
// attributes using slog's alternating key/value convention (see
// slog.Record.Add). They are not format arguments; use Newf for formatting.
//
// Example:
//
//	this is correct (√)
//	errkit.New("connection failed", "address", host, "port", port)
//
//	this is also correct (√)
//	errkit.New("timeout occurred")
//
//	this is not recommended (x) - use Newf instead
//	errkit.New("error on host %s", host)
func New(msg string, args ...interface{}) *ErrorX {
	x := &ErrorX{}
	// called directly so the traced source (if enabled) is New's caller
	x.init()
	x.append(errors.New(msg))
	if len(args) != 0 {
		x.record.Add(args...)
	}
	return x
}

// Newf creates a new error whose message is built from format and args.
// The message is produced with fmt.Errorf, so a %w verb wraps its error
// operand: the wrapped error stays reachable through errors.Is/errors.As
// (via ErrorX.Unwrap). fmt.Sprintf, which does not support %w, would render
// it as a "%!w(...)" artifact instead. Without %w the message is identical
// to fmt.Sprintf's.
//
// Example:
//
//	errkit.Newf("connection failed on %s:%d", host, port)
//	errkit.Newf("reading config: %w", err)
func Newf(format string, args ...interface{}) *ErrorX {
	e := &ErrorX{}
	// called directly so the traced source (if enabled) is Newf's caller
	e.init()
	e.append(fmt.Errorf(format, args...))
	return e
}

// Msg appends message as an additional entry in the error chain. It does
// nothing on a nil receiver, and a message identical to an entry already
// present is dropped by the chain's de-duplication.
//
// Example:
//
//	myError.Msg("connection failed")
func (e *ErrorX) Msg(message string) {
	if e != nil {
		e.append(errors.New(message))
	}
}

// Msgf appends a formatted message to the error chain. It does nothing on a
// nil receiver. The message is built with fmt.Errorf, so a %w verb wraps its
// error operand and keeps it reachable through errors.Is/errors.As instead of
// producing the "%!w(...)" artifact that fmt.Sprintf renders.
//
// Example:
//
//	this is correct (√)
//	myError.Msgf("dial error on %s:%d", host, port)
//
//	this is also correct (√)
//	myError.Msgf("connection failed")
func (e *ErrorX) Msgf(format string, args ...interface{}) {
	if e == nil {
		return
	}
	e.append(fmt.Errorf(format, args...))
}

// SetKind sets the kind (class) of the error and returns e for chaining.
// When a kind is already set, the two are merged with CombineErrKinds rather
// than replaced; per the original contract the previously set kind is given
// preference when rendering (see CombineErrKinds for the exact rules).
// On a nil receiver it does nothing and returns nil, like Msg and Msgf, so
// FromError(nil).SetKind(k) does not panic.
//
// Example:
//
//	this is correct (√)
//	myError.SetKind(errkit.ErrKindNetworkPermanent)
func (e *ErrorX) SetKind(kind ErrKind) *ErrorX {
	if e == nil {
		return nil
	}
	if e.kind == nil {
		e.kind = kind
		return e
	}
	e.kind = CombineErrKinds(e.kind, kind)
	return e
}

// ResetKind clears the kind of the error, so Kind reports ErrKindUnknown until
// a new kind is set. It returns e for chaining. On a nil receiver it does
// nothing and returns nil.
//
//	Example:
//
//	myError.ResetKind()
func (e *ErrorX) ResetKind() *ErrorX {
	if e == nil {
		return nil
	}
	e.kind = nil
	return e
}

// SetAttr appends the given slog attributes to the error and returns e for
// chaining. Attributes are added as-is with slog.Record.Add: duplicate keys
// are not filtered, so setting the same key twice records it twice. On a nil
// receiver it does nothing and returns nil.
//
//	Example:
//
//	myError.SetAttr(slog.String("address", host))
//
// Deprecated: use WithAttr instead, e.g.
// errkit.WithAttr(myError, slog.String("address", host)).
func (e *ErrorX) SetAttr(s ...slog.Attr) *ErrorX {
	if e == nil {
		return nil
	}
	e.init()
	for _, attr := range s {
		e.record.Add(attr)
	}
	return e
}

// parseError recursively flattens err into to. It appends the individual
// errors it finds to to.errs (de-duplicated by message) and stops descending
// once to holds MaxErrorDepth entries. Recognised shapes, checked in order:
//
//   - *ErrorX: its errors are appended, and its attributes, source and kind
//     are merged into to (a typed-nil *ErrorX is ignored)
//   - JoinedError: each non-nil joined error is appended as-is; with none,
//     the message is parsed as a plain string instead
//   - WrappedError: the unwrapped error is parsed, or the message when
//     Unwrap returns nil
//   - CauseError: the cause is appended, then whatever remains of the
//     message after removing the cause text is parsed (when not blank)
//   - anything else: the message is split on DelimArrow (last part first),
//     DelimArrowSerialized (last part first), DelimSemiColon, or the
//     MultiLineErrPrefix/DelimMultiLine format, and each non-blank part is
//     parsed. A message without delimiters, or made up only of blank parts,
//     is appended unchanged.
//
// Panics raised while calling into other error implementations are recovered
// and recorded as an extra "error while unwrapping" entry.
func parseError(to *ErrorX, err error) {
	// nothing to parse, or no destination whose result the caller could see
	if err == nil || to == nil {
		return
	}
	// guard against panics in external libraries calls
	defer func() {
		if r := recover(); r != nil {
			// Convert panic to error and append it as the last error
			var panicErr error
			switch v := r.(type) {
			case error:
				panicErr = fmt.Errorf("error while unwrapping: %w", v)
			case string:
				panicErr = fmt.Errorf("error while unwrapping: %s", v)
			default:
				panicErr = fmt.Errorf("error while unwrapping: panic: %v", r)
			}
			to.append(panicErr)
		}
	}()
	if len(to.errs) >= MaxErrorDepth {
		return
	}

	switch v := err.(type) {
	case *ErrorX:
		if v == nil {
			// a typed-nil *ErrorX carries no information
			return
		}
		to.append(v.errs...)
		if v.record != nil {
			if to.record == nil {
				// clone so later additions to to.record never leak into v
				cloned := v.record.Clone()
				to.record = &cloned
			} else {
				v.record.Attrs(func(a slog.Attr) bool {
					to.record.Add(a)
					return true
				})
			}
		}
		if to.source == nil && v.source != nil {
			to.source = v.source
		}
		to.kind = CombineErrKinds(to.kind, v.kind)
	case JoinedError:
		foundAny := false
		for _, e := range v.Unwrap() {
			if e == nil {
				continue
			}
			to.append(e)
			foundAny = true
		}
		if !foundAny {
			parseError(to, errors.New(err.Error()))
		}
	case WrappedError:
		if inner := v.Unwrap(); inner != nil {
			parseError(to, inner)
		} else {
			parseError(to, errors.New(err.Error()))
		}
	case CauseError:
		cause := v.Cause()
		if cause == nil {
			// no cause to extract; fall back to parsing the plain message
			parseError(to, errors.New(err.Error()))
			return
		}
		to.append(cause)
		remaining := strings.ReplaceAll(err.Error(), cause.Error(), "")
		if strings.TrimSpace(remaining) != "" {
			parseError(to, errors.New(remaining))
		}
	default:
		msg := err.Error()
		var parts []string
		reverse := false
		switch {
		case strings.Contains(msg, DelimArrow):
			parts, reverse = strings.Split(msg, DelimArrow), true
		case strings.Contains(msg, DelimArrowSerialized):
			parts, reverse = strings.Split(msg, DelimArrowSerialized), true
		case strings.Contains(msg, DelimSemiColon):
			parts = strings.Split(msg, DelimSemiColon)
		case strings.Contains(msg, MultiLineErrPrefix):
			body := strings.ReplaceAll(msg, MultiLineErrPrefix, "")
			parts = strings.Split(body, DelimMultiLine)
		}
		parsed := false
		for i := range parts {
			part := parts[i]
			if reverse {
				// arrow-joined parts are consumed last-to-first
				part = parts[len(parts)-1-i]
			}
			part = strings.TrimSpace(part)
			if part == "" {
				// delimiter artefacts (e.g. the blank segment in front of the
				// first multi-line entry) are not errors and would otherwise
				// waste a MaxErrorDepth slot
				continue
			}
			parsed = true
			parseError(to, errors.New(part))
		}
		if !parsed {
			// this cannot be further unwrapped
			to.append(err)
		}
	}
}
